From ac82bfbd4fbf092ba762e54a3e6207bf2fdf84cf Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Fri, 3 Apr 2026 10:46:21 +0300 Subject: [PATCH 01/39] feat(dovecot-charm): add TLS certificate integration via certificates relation --- dovecot-charm/charmcraft.yaml | 8 + .../v4/tls_certificates.py | 2526 +++++++++++++++++ dovecot-charm/pyproject.toml | 8 + dovecot-charm/src/charm.py | 56 + .../v4/tls_certificates.py | 2525 ++++++++++++++++ dovecot-charm/templates/dovecot.conf.tmpl | 20 +- dovecot-charm/tests/integration/test_tls.py | 123 + dovecot-charm/tests/unit/test_charm.py | 53 + dovecot-charm/uv.lock | 121 + 9 files changed, 5430 insertions(+), 10 deletions(-) create mode 100644 dovecot-charm/lib/charms/tls_certificates_interface/v4/tls_certificates.py create mode 100644 dovecot-charm/src/charms/tls_certificates_interface/v4/tls_certificates.py create mode 100644 dovecot-charm/tests/integration/test_tls.py diff --git a/dovecot-charm/charmcraft.yaml b/dovecot-charm/charmcraft.yaml index 5fc2869..de95635 100644 --- a/dovecot-charm/charmcraft.yaml +++ b/dovecot-charm/charmcraft.yaml @@ -7,6 +7,10 @@ base: ubuntu@24.04 platforms: amd64: +charm-libs: + - lib: tls-certificates-interface.tls_certificates + version: "4" + name: dovecot-charm summary: Dovecot IMAP/POP3 mail server charm @@ -39,6 +43,10 @@ peers: replicas: interface: replicas +requires: + certificates: + interface: tls-certificates + config: options: mailname: diff --git a/dovecot-charm/lib/charms/tls_certificates_interface/v4/tls_certificates.py b/dovecot-charm/lib/charms/tls_certificates_interface/v4/tls_certificates.py new file mode 100644 index 0000000..32b3b15 --- /dev/null +++ b/dovecot-charm/lib/charms/tls_certificates_interface/v4/tls_certificates.py @@ -0,0 +1,2526 @@ +# Copyright 2024 Canonical Ltd. +# See LICENSE file for licensing details. + +"""Legacy Charmhub-hosted lib, deprecated in favour of ``charmlibs.interfaces.tls_certificates``. + +WARNING: This library is deprecated. +It will not receive feature updates or bugfixes. +``charmlibs.interfaces.tls_certificates`` 1.0 is a bug-for-bug compatible migration of this library. + +To migrate: +1. Add 'charmlibs-interfaces-tls-certificates~=1.0' to your charm's dependencies, + and remove this Charmhub-hosted library from your charm. +2. You can also remove any dependencies added to your charm only because of this library. +3. Replace `from charms.tls_certificates_interface.v4 import tls_certificates` + with `from charmlibs.interfaces import tls_certificates`. + +Read more: +- https://documentation.ubuntu.com/charmlibs +- https://pypi.org/project/charmlibs-interfaces-tls-certificates + +--- + +Charm library for managing TLS certificates (V4). + +This library contains the Requires and Provides classes for handling the tls-certificates +interface. + +Pre-requisites: + - Juju >= 3.0 + - cryptography >= 43.0.0 + - pydantic >= 1.0 + +Learn more on how-to use the TLS Certificates interface library by reading the documentation: +- https://charmhub.io/tls-certificates-interface/ + +""" # noqa: D214, D405, D411, D416 + +import copy +import ipaddress +import json +import logging +import uuid +import warnings +from contextlib import suppress +from dataclasses import asdict, dataclass, field +from datetime import datetime, timedelta, timezone +from enum import Enum +from typing import ( + Collection, + Dict, + FrozenSet, + List, + MutableMapping, + Optional, + Set, + Tuple, + Union, +) + +import pydantic +from cryptography import x509 +from cryptography.exceptions import InvalidSignature +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.asymmetric.types import CertificateIssuerPrivateKeyTypes +from cryptography.x509.oid import ExtensionOID, NameOID +from ops import BoundEvent, CharmBase, CharmEvents, Secret, SecretExpiredEvent, SecretRemoveEvent +from ops.framework import EventBase, EventSource, Handle, Object +from ops.jujuversion import JujuVersion +from ops.model import Application, ModelError, Relation, SecretNotFoundError, Unit + +# The unique Charmhub library identifier, never change it +LIBID = "afd8c2bccf834997afce12c2706d2ede" + +# Increment this major API version when introducing breaking changes +LIBAPI = 4 + +# Increment this PATCH version before using `charmcraft publish-lib` or reset +# to 0 if you are raising the major API version +LIBPATCH = 27 + +PYDEPS = [ + "cryptography>=43.0.0", + "pydantic", +] +IS_PYDANTIC_V1 = int(pydantic.version.VERSION.split(".")[0]) < 2 + +logger = logging.getLogger(__name__) + +NESTED_JSON_KEY = "owasp_event" + + +@dataclass +class _OWASPLogEvent: + """OWASP-compliant log event.""" + + datetime: str + event: str + level: str + description: str + type: str = "security" + labels: Dict[str, str] = field(default_factory=dict) + + def to_json(self) -> str: + return json.dumps(self.to_dict(), ensure_ascii=False) + + def to_dict(self) -> Dict: + log_event = dict(asdict(self), **self.labels) + log_event.pop("labels", None) + return {k: v for k, v in log_event.items() if v is not None} + + +class _OWASPLogger: + """OWASP-compliant logger for security events.""" + + def __init__(self, application: Optional[str] = None): + self.application = application + self._logger = logging.getLogger(__name__) + + def log_event(self, event: str, level: int, description: str, **labels: str): + if self.application and "application" not in labels: + labels["application"] = self.application + log = _OWASPLogEvent( + datetime=datetime.now(timezone.utc).astimezone().isoformat(), + event=event, + level=logging.getLevelName(level), + description=description, + labels=labels, + ) + self._logger.log(level, log.to_json(), extra={NESTED_JSON_KEY: log.to_dict()}) + + +class TLSCertificatesError(Exception): + """Base class for custom errors raised by this library.""" + + +class DataValidationError(TLSCertificatesError): + """Raised when data validation fails.""" + + +class _DatabagModel(pydantic.BaseModel): + """Base databag model. + + Supports both pydantic v1 and v2. + """ + + if IS_PYDANTIC_V1: + + class Config: + """Pydantic config.""" + + # ignore any extra fields in the databag + extra = "ignore" + """Ignore any extra fields in the databag.""" + allow_population_by_field_name = True + """Allow instantiating this class by field name (instead of forcing alias).""" + + _NEST_UNDER = None + + model_config = pydantic.ConfigDict( + # tolerate additional keys in databag + extra="ignore", + # Allow instantiating this class by field name (instead of forcing alias). + populate_by_name=True, + # Custom config key: whether to nest the whole datastructure (as json) + # under a field or spread it out at the toplevel. + _NEST_UNDER=None, + ) # type: ignore + """Pydantic config.""" + + @classmethod + def load(cls, databag: MutableMapping): + """Load this model from a Juju databag.""" + if IS_PYDANTIC_V1: + return cls._load_v1(databag) + nest_under = cls.model_config.get("_NEST_UNDER") + if nest_under: + return cls.model_validate(json.loads(databag[nest_under])) + + try: + data = { + k: json.loads(v) + for k, v in databag.items() + # Don't attempt to parse model-external values + if k in {(f.alias or n) for n, f in cls.model_fields.items()} + } + except json.JSONDecodeError as e: + msg = f"invalid databag contents: expecting json. {databag}" + logger.error(msg) + raise DataValidationError(msg) from e + + try: + return cls.model_validate_json(json.dumps(data)) + except pydantic.ValidationError as e: + msg = f"failed to validate databag: {databag}" + logger.debug(msg, exc_info=True) + raise DataValidationError(msg) from e + + @classmethod + def _load_v1(cls, databag: MutableMapping): + """Load implementation for pydantic v1.""" + if cls._NEST_UNDER: + return cls.parse_obj(json.loads(databag[cls._NEST_UNDER])) + + try: + data = { + k: json.loads(v) + for k, v in databag.items() + # Don't attempt to parse model-external values + if k in {f.alias for f in cls.__fields__.values()} + } + except json.JSONDecodeError as e: + msg = f"invalid databag contents: expecting json. {databag}" + logger.error(msg) + raise DataValidationError(msg) from e + + try: + return cls.parse_raw(json.dumps(data)) # type: ignore + except pydantic.ValidationError as e: + msg = f"failed to validate databag: {databag}" + logger.debug(msg, exc_info=True) + raise DataValidationError(msg) from e + + def dump(self, databag: Optional[MutableMapping] = None, clear: bool = True): + """Write the contents of this model to Juju databag. + + Args: + databag: The databag to write to. + clear: Whether to clear the databag before writing. + + Returns: + MutableMapping: The databag. + """ + if IS_PYDANTIC_V1: + return self._dump_v1(databag, clear) + if clear and databag: + databag.clear() + + if databag is None: + databag = {} + nest_under = self.model_config.get("_NEST_UNDER") + if nest_under: + databag[nest_under] = self.model_dump_json( + by_alias=True, + # skip keys whose values are default + exclude_defaults=True, + ) + return databag + + dct = self.model_dump(mode="json", by_alias=True, exclude_defaults=True) + databag.update({k: json.dumps(v) for k, v in dct.items()}) + return databag + + def _dump_v1(self, databag: Optional[MutableMapping] = None, clear: bool = True): + """Dump implementation for pydantic v1.""" + if clear and databag: + databag.clear() + + if databag is None: + databag = {} + + if self._NEST_UNDER: + databag[self._NEST_UNDER] = self.json(by_alias=True, exclude_defaults=True) + return databag + + dct = json.loads(self.json(by_alias=True, exclude_defaults=True)) + databag.update({k: json.dumps(v) for k, v in dct.items()}) + + return databag + + +class _Certificate(pydantic.BaseModel): + """Certificate model.""" + + ca: str + certificate_signing_request: str + certificate: str + chain: Optional[List[str]] = None + revoked: Optional[bool] = None + + def to_provider_certificate(self, relation_id: int) -> "ProviderCertificate": + """Convert to a ProviderCertificate.""" + return ProviderCertificate( + relation_id=relation_id, + certificate=Certificate.from_string(self.certificate), + certificate_signing_request=CertificateSigningRequest.from_string( + self.certificate_signing_request + ), + ca=Certificate.from_string(self.ca), + chain=[Certificate.from_string(certificate) for certificate in self.chain] + if self.chain + else [], + revoked=self.revoked, + ) + + +class _CertificateSigningRequest(pydantic.BaseModel): + """Certificate signing request model.""" + + certificate_signing_request: str + ca: Optional[bool] + + +class _ProviderApplicationData(_DatabagModel): + """Provider application data model.""" + + certificates: List[_Certificate] = [] + + +class _RequirerData(_DatabagModel): + """Requirer data model. + + The same model is used for the unit and application data. + """ + + certificate_signing_requests: List[_CertificateSigningRequest] = [] + + +class Mode(Enum): + """Enum representing the mode of the certificate request. + + UNIT (default): Request a certificate for the unit. + Each unit will manage its private key, + certificate signing request and certificate. + APP: Request a certificate for the application. + Only the leader unit will manage the private key, certificate signing request + and certificate. + """ + + UNIT = 1 + APP = 2 + + +class PrivateKey: + """This class represents a private key.""" + + def __init__( + self, raw: Optional[str] = None, x509_object: Optional[rsa.RSAPrivateKey] = None + ) -> None: + """Initialize the PrivateKey object. + + If both raw and x509_object are provided, x509_object takes precedence. + """ + if x509_object: + self._private_key = x509_object + elif raw: + self._private_key = serialization.load_pem_private_key( + raw.encode(), + password=None, + ) + else: + raise ValueError("Either raw private key string or x509_object must be provided") + + @property + def raw(self) -> str: + """Return the PEM-formatted string representation of the private key.""" + return str(self) + + def __str__(self): + """Return the private key as a string in PEM format.""" + return ( + self._private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + .decode() + .strip() + ) + + def __hash__(self): + """Return the hash of the private key.""" + return hash(self.raw) + + @classmethod + def from_string(cls, private_key: str) -> "PrivateKey": + """Create a PrivateKey object from a private key.""" + return cls(raw=private_key) + + def is_valid(self) -> bool: + """Validate that the private key is PEM-formatted, RSA, and at least 2048 bits.""" + try: + if not isinstance(self._private_key, rsa.RSAPrivateKey): + logger.warning("Private key is not an RSA key") + return False + + if self._private_key.key_size < 2048: + logger.warning("RSA key size is less than 2048 bits") + return False + + return True + except ValueError: + logger.warning("Invalid private key format") + return False + + @classmethod + def generate(cls, key_size: int = 2048, public_exponent: int = 65537) -> "PrivateKey": + """Generate a new RSA private key. + + Args: + key_size: The size of the key in bits. + public_exponent: The public exponent of the key. + + Returns: + PrivateKey: The generated private key. + """ + private_key = rsa.generate_private_key( + public_exponent=public_exponent, + key_size=key_size, + ) + _OWASPLogger().log_event( + event="private_key_generated", + level=logging.INFO, + description="Private key generated", + key_size=str(key_size), + ) + return PrivateKey(x509_object=private_key) + + def __eq__(self, other: object) -> bool: + """Check if two PrivateKey objects are equal.""" + if not isinstance(other, PrivateKey): + return NotImplemented + return self.raw == other.raw + + +class Certificate: + """This class represents a certificate.""" + + _cert: x509.Certificate + + def __init__( + self, + raw: Optional[str] = None, # Must remain first argument for backwards compatibility + # Old Interface fields (ignored) + common_name: Optional[str] = None, + expiry_time: Optional[datetime] = None, + validity_start_time: Optional[datetime] = None, + is_ca: Optional[bool] = None, + sans_dns: Optional[Set[str]] = None, + sans_ip: Optional[Set[str]] = None, + sans_oid: Optional[Set[str]] = None, + email_address: Optional[str] = None, + organization: Optional[str] = None, + organizational_unit: Optional[str] = None, + country_name: Optional[str] = None, + state_or_province_name: Optional[str] = None, + locality_name: Optional[str] = None, + # End Old Interface fields + x509_object: Optional[x509.Certificate] = None, + ) -> None: + """Initialize the Certificate object. + + This initializer must maintain the old interface while also allowing + instantiation from an existing x509_object. It ignores all fields + other than raw and x509_object, preferring x509_object. + """ + if x509_object: + self._cert = x509_object + elif raw: + self._cert = x509.load_pem_x509_certificate(data=raw.encode()) + else: + raise ValueError("Either raw certificate string or x509_object must be provided") + + @property + def raw(self) -> str: + """Return the PEM-formatted string representation of the certificate.""" + return str(self) + + @property + def common_name(self) -> str: + """Return the common name of the certificate.""" + # We maintain compatibility with the old interface by returning + # an empty string if no common name is set. + common_name = self._cert.subject.get_attributes_for_oid(NameOID.COMMON_NAME) + return str(common_name[0].value) if common_name else "" + + @property + def expiry_time(self) -> datetime: + """Return the expiry time of the certificate.""" + return self._cert.not_valid_after_utc + + @property + def validity_start_time(self) -> datetime: + """Return the validity start time of the certificate.""" + return self._cert.not_valid_before_utc + + @property + def is_ca(self) -> bool: + """Return whether the certificate is a CA certificate.""" + try: + return self._cert.extensions.get_extension_for_oid( + ExtensionOID.BASIC_CONSTRAINTS + ).value.ca # type: ignore[reportAttributeAccessIssue] + except x509.ExtensionNotFound: + return False + + @property + def sans_dns(self) -> Optional[Set[str]]: + """Return the DNS Subject Alternative Names of the certificate.""" + with suppress(x509.ExtensionNotFound): + sans = self._cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value + return {str(san) for san in sans.get_values_for_type(x509.DNSName)} + return None + + @property + def sans_ip(self) -> Optional[Set[str]]: + """Return the IP Subject Alternative Names of the certificate.""" + with suppress(x509.ExtensionNotFound): + sans = self._cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value + return {str(san) for san in sans.get_values_for_type(x509.IPAddress)} + return None + + @property + def sans_oid(self) -> Optional[Set[str]]: + """Return the OID Subject Alternative Names of the certificate.""" + with suppress(x509.ExtensionNotFound): + sans = self._cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value + return {str(san.dotted_string) for san in sans.get_values_for_type(x509.RegisteredID)} + return None + + @property + def email_address(self) -> Optional[str]: + """Return the email address of the certificate.""" + email_address = self._cert.subject.get_attributes_for_oid(NameOID.EMAIL_ADDRESS) + return str(email_address[0].value) if email_address else None + + @property + def organization(self) -> Optional[str]: + """Return the organization name of the certificate.""" + organization = self._cert.subject.get_attributes_for_oid(NameOID.ORGANIZATION_NAME) + return str(organization[0].value) if organization else None + + @property + def organizational_unit(self) -> Optional[str]: + """Return the organizational unit name of the certificate.""" + organizational_unit = self._cert.subject.get_attributes_for_oid( + NameOID.ORGANIZATIONAL_UNIT_NAME + ) + return str(organizational_unit[0].value) if organizational_unit else None + + @property + def country_name(self) -> Optional[str]: + """Return the country name of the certificate.""" + country_name = self._cert.subject.get_attributes_for_oid(NameOID.COUNTRY_NAME) + return str(country_name[0].value) if country_name else None + + @property + def state_or_province_name(self) -> Optional[str]: + """Return the state or province name of the certificate.""" + state_or_province_name = self._cert.subject.get_attributes_for_oid( + NameOID.STATE_OR_PROVINCE_NAME + ) + return str(state_or_province_name[0].value) if state_or_province_name else None + + @property + def locality_name(self) -> Optional[str]: + """Return the locality name of the certificate.""" + locality_name = self._cert.subject.get_attributes_for_oid(NameOID.LOCALITY_NAME) + return str(locality_name[0].value) if locality_name else None + + def __str__(self) -> str: + """Return the certificate as a string.""" + return self._cert.public_bytes(serialization.Encoding.PEM).decode().strip() + + def __eq__(self, other: object) -> bool: + """Check if two Certificate objects are equal.""" + if not isinstance(other, Certificate): + return NotImplemented + return self.raw == other.raw + + @classmethod + def from_string(cls, certificate: str) -> "Certificate": + """Create a Certificate object from a certificate.""" + try: + certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) + except ValueError as e: + logger.error("Could not load certificate: %s", e) + raise TLSCertificatesError("Could not load certificate") + + return cls(x509_object=certificate_object) + + def matches_private_key(self, private_key: PrivateKey) -> bool: + """Check if this certificate matches a given private key. + + Args: + private_key (PrivateKey): The private key to validate against. + + Returns: + bool: True if the certificate matches the private key, False otherwise. + """ + try: + cert_public_key = self._cert.public_key() + key_public_key = private_key._private_key.public_key() + + if not isinstance(cert_public_key, rsa.RSAPublicKey): + logger.warning("Certificate does not use RSA public key") + return False + + if not isinstance(key_public_key, rsa.RSAPublicKey): + logger.warning("Private key is not an RSA key") + return False + + return cert_public_key.public_numbers() == key_public_key.public_numbers() + except Exception as e: + logger.warning("Failed to validate certificate and private key match: %s", e) + return False + + @classmethod + def generate( + cls, + csr: "CertificateSigningRequest", + ca: "Certificate", + ca_private_key: "PrivateKey", + validity: timedelta, + is_ca: bool = False, + ) -> "Certificate": + """Generate a certificate from a CSR signed by the given CA and CA private key. + + Args: + csr: The certificate signing request. + ca: The CA certificate. + ca_private_key: The CA private key. + validity: The validity period of the certificate. + is_ca: Whether the generated certificate is a CA certificate. + + Returns: + Certificate: The generated certificate. + """ + # Ideally, this would be the constructor, but we can't add new + # required parameters to the constructor without breaking backwards + # compatibility. + private_key = serialization.load_pem_private_key( + str(ca_private_key).encode(), password=None + ) + assert isinstance(private_key, CertificateIssuerPrivateKeyTypes) + + # Create a certificate builder + cert_builder = x509.CertificateBuilder( + subject_name=csr._csr.subject, + # issuer_name=ca._cert.subject, # TODO: Validate this is correct, the old code used `issuer` + issuer_name=ca._cert.issuer, + public_key=csr._csr.public_key(), + serial_number=x509.random_serial_number(), + not_valid_before=datetime.now(timezone.utc), + not_valid_after=datetime.now(timezone.utc) + validity, + ) + extensions = _generate_certificate_request_extensions( + authority_key_identifier=ca._cert.extensions.get_extension_for_class( + x509.SubjectKeyIdentifier + ).value.key_identifier, + csr=csr._csr, + is_ca=is_ca, + ) + for extension in extensions: + try: + cert_builder = cert_builder.add_extension(extension.value, extension.critical) + except ValueError as e: + logger.error("Could not add extension to certificate: %s", e) + raise TLSCertificatesError("Could not add extension to certificate") from e + + # Sign the certificate with the CA's private key + cert = cert_builder.sign(private_key=private_key, algorithm=hashes.SHA256()) + _OWASPLogger().log_event( + event="certificate_generated", + level=logging.INFO, + description="Certificate generated from CSR", + common_name=csr.common_name, + is_ca=str(is_ca), + validity_days=str(validity.days), + ) + + return cls(x509_object=cert) + + @classmethod + def generate_self_signed_ca( + cls, + attributes: "CertificateRequestAttributes", + private_key: PrivateKey, + validity: timedelta, + ) -> "Certificate": + """Generate a self-signed CA certificate. + + Args: + attributes: The certificate request attributes. + private_key: The private key to sign the CA certificate. + validity: The validity period of the CA certificate. + + Returns: + Certificate: The generated CA certificate. + """ + assert isinstance(private_key._private_key, rsa.RSAPrivateKey) + + public_key = private_key._private_key.public_key() + + builder = x509.CertificateBuilder( + public_key=public_key, + serial_number=x509.random_serial_number(), + not_valid_before=datetime.now(timezone.utc), + not_valid_after=datetime.now(timezone.utc) + validity, + ) + + if subject_name := _extract_subject_name_attributes(attributes): + builder = builder.subject_name(subject_name).issuer_name(subject_name) + + builder = ( + builder.add_extension( + x509.SubjectKeyIdentifier.from_public_key(public_key), critical=False + ) + .add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True) + .add_extension( + x509.KeyUsage( + digital_signature=True, + key_encipherment=True, + key_cert_sign=True, + key_agreement=False, + content_commitment=False, + data_encipherment=False, + crl_sign=False, + encipher_only=False, + decipher_only=False, + ), + critical=True, + ) + ) + + if san_extension := _san_extension( + email_address=attributes.email_address, + sans_dns=attributes.sans_dns, + sans_ip=attributes.sans_ip, + sans_oid=attributes.sans_oid, + ): + builder = builder.add_extension(san_extension, critical=False) + + cert = cls(x509_object=builder.sign(private_key._private_key, algorithm=hashes.SHA256())) + + _OWASPLogger().log_event( + event="ca_certificate_generated", + level=logging.INFO, + description="CA certificate generated", + common_name=cert.common_name, + validity_days=str(validity.days), + ) + + return cert + + def __hash__(self): + """Return the hash of the private key.""" + return hash(self.raw) + + +class CertificateSigningRequest: + """A representation of the certificate signing request.""" + + _csr: x509.CertificateSigningRequest + + def __init__( + self, + raw: Optional[str] = None, # Must remain first argument for backwards compatibility + # Old Interface fields (ignored) + common_name: Optional[str] = None, + sans_dns: Optional[Set[str]] = None, + sans_ip: Optional[Set[str]] = None, + sans_oid: Optional[Set[str]] = None, + email_address: Optional[str] = None, + organization: Optional[str] = None, + organizational_unit: Optional[str] = None, + country_name: Optional[str] = None, + state_or_province_name: Optional[str] = None, + locality_name: Optional[str] = None, + has_unique_identifier: Optional[bool] = None, + # End Old Interface fields + x509_object: Optional[x509.CertificateSigningRequest] = None, + ): + """Initialize the CertificateSigningRequest object. + + This initializer must maintain the old interface while also allowing + instantiation from an existing x509_object. It ignores all fields + other than raw and x509_object, preferring x509_object. + """ + if x509_object: + self._csr = x509_object + return + elif raw: + try: + self._csr = x509.load_pem_x509_csr(raw.encode()) + except ValueError as e: + logger.error("Could not load CSR: %s", e) + raise TLSCertificatesError("Could not load CSR") + return + raise ValueError("Either raw CSR string or x509_object must be provided") + + @property + def common_name(self) -> str: + """Return the common name of the CSR.""" + common_name = self._csr.subject.get_attributes_for_oid(NameOID.COMMON_NAME) + return str(common_name[0].value) if common_name else "" + + @property + def sans_dns(self) -> Set[str]: + """Return the DNS Subject Alternative Names of the CSR.""" + with suppress(x509.ExtensionNotFound): + sans = self._csr.extensions.get_extension_for_class(x509.SubjectAlternativeName).value + return {str(san) for san in sans.get_values_for_type(x509.DNSName)} + return set() + + @property + def sans_ip(self) -> Set[str]: + """Return the IP Subject Alternative Names of the CSR.""" + with suppress(x509.ExtensionNotFound): + sans = self._csr.extensions.get_extension_for_class(x509.SubjectAlternativeName).value + return {str(san) for san in sans.get_values_for_type(x509.IPAddress)} + return set() + + @property + def sans_oid(self) -> Set[str]: + """Return the OID Subject Alternative Names of the CSR.""" + with suppress(x509.ExtensionNotFound): + sans = self._csr.extensions.get_extension_for_class(x509.SubjectAlternativeName).value + return {str(san.dotted_string) for san in sans.get_values_for_type(x509.RegisteredID)} + return set() + + @property + def email_address(self) -> Optional[str]: + """Return the email address of the CSR.""" + email_address = self._csr.subject.get_attributes_for_oid(NameOID.EMAIL_ADDRESS) + return str(email_address[0].value) if email_address else None + + @property + def organization(self) -> Optional[str]: + """Return the organization name of the CSR.""" + organization = self._csr.subject.get_attributes_for_oid(NameOID.ORGANIZATION_NAME) + return str(organization[0].value) if organization else None + + @property + def organizational_unit(self) -> Optional[str]: + """Return the organizational unit name of the CSR.""" + organizational_unit = self._csr.subject.get_attributes_for_oid( + NameOID.ORGANIZATIONAL_UNIT_NAME + ) + return str(organizational_unit[0].value) if organizational_unit else None + + @property + def country_name(self) -> Optional[str]: + """Return the country name of the CSR.""" + country_name = self._csr.subject.get_attributes_for_oid(NameOID.COUNTRY_NAME) + return str(country_name[0].value) if country_name else None + + @property + def state_or_province_name(self) -> Optional[str]: + """Return the state or province name of the CSR.""" + state_or_province_name = self._csr.subject.get_attributes_for_oid( + NameOID.STATE_OR_PROVINCE_NAME + ) + return str(state_or_province_name[0].value) if state_or_province_name else None + + @property + def locality_name(self) -> Optional[str]: + """Return the locality name of the CSR.""" + locality_name = self._csr.subject.get_attributes_for_oid(NameOID.LOCALITY_NAME) + return str(locality_name[0].value) if locality_name else None + + @property + def has_unique_identifier(self) -> bool: + """Return whether the CSR has a unique identifier.""" + unique_identifier = self._csr.subject.get_attributes_for_oid( + NameOID.X500_UNIQUE_IDENTIFIER + ) + return bool(unique_identifier) + + @property + def raw(self) -> str: + """Return the PEM-formatted string representation of the CSR.""" + return self.__str__() + + def __str__(self) -> str: + """Return the CSR as a string.""" + return self._csr.public_bytes(serialization.Encoding.PEM).decode().strip() + + @property + def additional_critical_extensions(self) -> List[x509.ExtensionType]: + """Return additional critical extensions present on the CSR (excluding SAN).""" + extensions: List[x509.ExtensionType] = [] + for extension in self._csr.extensions: + if extension.critical and extension.oid != ExtensionOID.SUBJECT_ALTERNATIVE_NAME: + extensions.append(extension.value) + return extensions + + @classmethod + def from_string(cls, csr: str) -> "CertificateSigningRequest": + """Create a CertificateSigningRequest object from a CSR.""" + return cls(raw=csr) + + @classmethod + def from_csr(cls, csr: x509.CertificateSigningRequest) -> "CertificateSigningRequest": + """Create a CertificateSigningRequest object from a CSR.""" + return cls(x509_object=csr) + + def __eq__(self, other: object) -> bool: + """Check if two CertificateSigningRequest objects are equal.""" + if not isinstance(other, CertificateSigningRequest): + return NotImplemented + return self.raw == other.raw + + def __hash__(self): + """Return the hash of the private key.""" + return hash(self.raw) + + def matches_certificate(self, certificate: Certificate) -> bool: + """Check if this CSR matches a given certificate. + + Args: + certificate (Certificate): The certificate to validate against. + + Returns: + bool: True if the CSR matches the certificate, False otherwise. + """ + return self._csr.public_key() == certificate._cert.public_key() + + def matches_private_key(self, key: PrivateKey) -> bool: + """Check if a CSR matches a private key. + + This function only works with RSA keys. + + Args: + key (PrivateKey): Private key + Returns: + bool: True/False depending on whether the CSR matches the private key. + """ + try: + key_object_public_key = key._private_key.public_key() + csr_object_public_key = self._csr.public_key() + if not isinstance(key_object_public_key, rsa.RSAPublicKey): + logger.warning("Key is not an RSA key") + return False + if not isinstance(csr_object_public_key, rsa.RSAPublicKey): + logger.warning("CSR is not an RSA key") + return False + if ( + csr_object_public_key.public_numbers().n + != key_object_public_key.public_numbers().n + ): + logger.warning("Public key numbers between CSR and key do not match") + return False + except ValueError: + logger.warning("Could not load certificate or CSR.") + return False + return True + + def get_sha256_hex(self) -> str: + """Calculate the hash of the provided data and return the hexadecimal representation.""" + digest = hashes.Hash(hashes.SHA256()) + digest.update(self.raw.encode()) + return digest.finalize().hex() + + def sign( + self, ca: Certificate, ca_private_key: PrivateKey, validity: timedelta, is_ca: bool = False + ) -> Certificate: + """Sign this CSR with the given CA and CA private key. + + Args: + ca: The CA certificate. + ca_private_key: The CA private key. + validity: The validity period of the certificate. + is_ca: Whether the generated certificate is a CA certificate. + + Returns: + Certificate: The signed certificate. + """ + return Certificate.generate( + csr=self, + ca=ca, + ca_private_key=ca_private_key, + validity=validity, + is_ca=is_ca, + ) + + @classmethod + def generate( + cls, + attributes: "CertificateRequestAttributes", + private_key: PrivateKey, + ) -> "CertificateSigningRequest": + """Generate a CSR using the supplied attributes and private key. + + Args: + attributes (CertificateRequestAttributes): Certificate request attributes + private_key (PrivateKey): Private key + Returns: + CertificateSigningRequest: CSR + """ + signing_key = private_key._private_key + assert isinstance(signing_key, CertificateIssuerPrivateKeyTypes) + + csr_builder = x509.CertificateSigningRequestBuilder() + if subject_name := _extract_subject_name_attributes(attributes): + csr_builder = csr_builder.subject_name(subject_name) + + _sans: List[x509.GeneralName] = [] + if attributes.sans_oid: + _sans.extend( + [x509.RegisteredID(x509.ObjectIdentifier(san)) for san in attributes.sans_oid] + ) + if attributes.sans_ip: + _sans.extend([x509.IPAddress(ipaddress.ip_address(san)) for san in attributes.sans_ip]) + if attributes.sans_dns: + _sans.extend([x509.DNSName(san) for san in attributes.sans_dns]) + if _sans: + csr_builder = csr_builder.add_extension( + x509.SubjectAlternativeName(set(_sans)), critical=False + ) + if attributes.additional_critical_extensions: + for extension in attributes.additional_critical_extensions: + csr_builder = csr_builder.add_extension(extension, critical=True) + signed_certificate_request = csr_builder.sign(signing_key, hashes.SHA256()) + return cls(x509_object=signed_certificate_request) + + +class CertificateRequestAttributes: + """A representation of the certificate request attributes.""" + + def __init__( + self, + common_name: Optional[str] = None, + sans_dns: Optional[Collection[str]] = None, + sans_ip: Optional[Collection[str]] = None, + sans_oid: Optional[Collection[str]] = None, + email_address: Optional[str] = None, + organization: Optional[str] = None, + organizational_unit: Optional[str] = None, + country_name: Optional[str] = None, + state_or_province_name: Optional[str] = None, + locality_name: Optional[str] = None, + is_ca: bool = False, + add_unique_id_to_subject_name: bool = True, + additional_critical_extensions: Optional[Collection[x509.ExtensionType]] = None, + ): + if not common_name and not sans_dns and not sans_ip and not sans_oid: + raise ValueError( + "At least one of common_name, sans_dns, sans_ip, or sans_oid must be provided" + ) + self._common_name = common_name + self._sans_dns = set(sans_dns) if sans_dns else None + self._sans_ip = set(sans_ip) if sans_ip else None + self._sans_oid = set(sans_oid) if sans_oid else None + self._email_address = email_address + self._organization = organization + self._organizational_unit = organizational_unit + self._country_name = country_name + self._state_or_province_name = state_or_province_name + self._locality_name = locality_name + self._is_ca = is_ca + self._add_unique_id_to_subject_name = add_unique_id_to_subject_name + self._additional_critical_extensions = list(additional_critical_extensions or []) + + @property + def common_name(self) -> str: + """Return the common name.""" + # For legacy interface compatibility, return empty string if not set + return self._common_name if self._common_name else "" + + @property + def sans_dns(self) -> Optional[Set[str]]: + """Return the DNS Subject Alternative Names.""" + return self._sans_dns + + @property + def sans_ip(self) -> Optional[Set[str]]: + """Return the IP Subject Alternative Names.""" + return self._sans_ip + + @property + def sans_oid(self) -> Optional[Set[str]]: + """Return the OID Subject Alternative Names.""" + return self._sans_oid + + @property + def email_address(self) -> Optional[str]: + """Return the email address.""" + return self._email_address + + @property + def organization(self) -> Optional[str]: + """Return the organization name.""" + return self._organization + + @property + def organizational_unit(self) -> Optional[str]: + """Return the organizational unit name.""" + return self._organizational_unit + + @property + def country_name(self) -> Optional[str]: + """Return the country name.""" + return self._country_name + + @property + def state_or_province_name(self) -> Optional[str]: + """Return the state or province name.""" + return self._state_or_province_name + + @property + def locality_name(self) -> Optional[str]: + """Return the locality name.""" + return self._locality_name + + @property + def is_ca(self) -> bool: + """Return whether the certificate is a CA certificate.""" + return self._is_ca + + @property + def add_unique_id_to_subject_name(self) -> bool: + """Return whether to add a unique identifier to the subject name.""" + return self._add_unique_id_to_subject_name + + @property + def additional_critical_extensions(self) -> List[x509.ExtensionType]: + """Return additional critical extensions to be added to the CSR.""" + return self._additional_critical_extensions + + @classmethod + def from_csr( + cls, csr: CertificateSigningRequest, is_ca: bool + ) -> "CertificateRequestAttributes": + """Create CertificateRequestAttributes from a CertificateSigningRequest. + + Args: + csr: The CSR to extract attributes from. + is_ca: Whether a CA certificate is being requested. + + Returns: + CertificateRequestAttributes: The extracted attributes. + """ + return cls( + common_name=csr.common_name, + sans_dns=csr.sans_dns, + sans_ip=csr.sans_ip, + sans_oid=csr.sans_oid, + email_address=csr.email_address, + organization=csr.organization, + organizational_unit=csr.organizational_unit, + country_name=csr.country_name, + state_or_province_name=csr.state_or_province_name, + locality_name=csr.locality_name, + is_ca=is_ca, + add_unique_id_to_subject_name=csr.has_unique_identifier, + additional_critical_extensions=csr.additional_critical_extensions, + ) + + def __eq__(self, other: object) -> bool: + """Check if two CertificateRequestAttributes objects are equal.""" + if not isinstance(other, CertificateRequestAttributes): + return NotImplemented + return ( + self.common_name == other.common_name + and self.sans_dns == other.sans_dns + and self.sans_ip == other.sans_ip + and self.sans_oid == other.sans_oid + and self.email_address == other.email_address + and self.organization == other.organization + and self.organizational_unit == other.organizational_unit + and self.country_name == other.country_name + and self.state_or_province_name == other.state_or_province_name + and self.locality_name == other.locality_name + and self.is_ca == other.is_ca + and self.add_unique_id_to_subject_name == other.add_unique_id_to_subject_name + and self.additional_critical_extensions == other.additional_critical_extensions + ) + + def is_valid(self) -> bool: + """Validate the attributes of the certificate request. + + Returns: + bool: True if the attributes are valid, False otherwise. + """ + if not self.common_name and not self.sans_dns and not self.sans_ip and not self.sans_oid: + logger.warning( + "At least one of common_name, sans_dns, sans_ip, or sans_oid must be provided" + ) + return False + return True + + def generate_csr( + self, + private_key: PrivateKey, + ) -> CertificateSigningRequest: + """Generate a CSR using the current attributes and a private key. + + Args: + private_key (PrivateKey): Private key to sign the CSR. + + Returns: + CertificateSigningRequest: The generated CSR. + """ + return CertificateSigningRequest.generate(self, private_key) + + +@dataclass(frozen=True) +class ProviderCertificate: + """This class represents a certificate provided by the TLS provider.""" + + relation_id: int + certificate: Certificate + certificate_signing_request: CertificateSigningRequest + ca: Certificate + chain: List[Certificate] + revoked: Optional[bool] = None + + def to_json(self) -> str: + """Return the object as a JSON string. + + Returns: + str: JSON representation of the object + """ + return json.dumps( + { + "csr": str(self.certificate_signing_request), + "certificate": str(self.certificate), + "ca": str(self.ca), + "chain": [str(cert) for cert in self.chain], + "revoked": self.revoked, + } + ) + + +@dataclass(frozen=True) +class RequirerCertificateRequest: + """This class represents a certificate signing request requested by a specific TLS requirer.""" + + relation_id: int + certificate_signing_request: CertificateSigningRequest + is_ca: bool + + +class CertificateAvailableEvent(EventBase): + """Charm Event triggered when a TLS certificate is available.""" + + def __init__( + self, + handle: Handle, + certificate: Certificate, + certificate_signing_request: CertificateSigningRequest, + ca: Certificate, + chain: List[Certificate], + ): + super().__init__(handle) + self.certificate = certificate + self.certificate_signing_request = certificate_signing_request + self.ca = ca + self.chain = chain + + def snapshot(self) -> dict: + """Return snapshot.""" + return { + "certificate": str(self.certificate), + "certificate_signing_request": str(self.certificate_signing_request), + "ca": str(self.ca), + "chain": json.dumps([str(certificate) for certificate in self.chain]), + } + + def restore(self, snapshot: dict): + """Restore snapshot.""" + self.certificate = Certificate.from_string(snapshot["certificate"]) + self.certificate_signing_request = CertificateSigningRequest.from_string( + snapshot["certificate_signing_request"] + ) + self.ca = Certificate.from_string(snapshot["ca"]) + chain_strs = json.loads(snapshot["chain"]) + self.chain = [Certificate.from_string(chain_str) for chain_str in chain_strs] + + def chain_as_pem(self) -> str: + """Return full certificate chain as a PEM string.""" + return "\n\n".join([str(cert) for cert in self.chain]) + + +def generate_private_key( + key_size: int = 2048, + public_exponent: int = 65537, +) -> PrivateKey: + """Generate a private key with the RSA algorithm. + + Args: + key_size (int): Key size in bits, must be at least 2048 bits + public_exponent: Public exponent. + + Returns: + PrivateKey: Private Key + """ + warnings.warn( + "generate_private_key() is deprecated. Use PrivateKey.generate() instead.", + DeprecationWarning, + ) + return PrivateKey.generate(key_size=key_size, public_exponent=public_exponent) + + +def calculate_relative_datetime(target_time: datetime, fraction: float) -> datetime: + """Calculate a datetime that is a given percentage from now to a target time. + + Args: + target_time (datetime): The future datetime to interpolate towards. + fraction (float): Fraction of the interval from now to target_time (0.0-1.0). + 1.0 means return target_time, + 0.9 means return the time after 90% of the interval has passed, + and 0.0 means return now. + """ + if fraction <= 0.0 or fraction > 1.0: + raise ValueError("Invalid fraction. Must be between 0.0 and 1.0") + now = datetime.now(timezone.utc) + time_until_target = target_time - now + return now + time_until_target * fraction + + +def chain_has_valid_order(chain: List[str]) -> bool: + """Check if the chain has a valid order. + + Validates that each certificate in the chain is properly signed by the next certificate. + The chain should be ordered from leaf to root, where each certificate is signed by + the next one in the chain. + + Args: + chain (List[str]): List of certificates in PEM format, ordered from leaf to root + + Returns: + bool: True if the chain has a valid order, False otherwise. + """ + if len(chain) < 2: + return True + + try: + for i in range(len(chain) - 1): + cert = x509.load_pem_x509_certificate(chain[i].encode()) + issuer = x509.load_pem_x509_certificate(chain[i + 1].encode()) + cert.verify_directly_issued_by(issuer) + return True + except (ValueError, TypeError, InvalidSignature): + return False + + +def generate_csr( # noqa: C901 + private_key: PrivateKey, + common_name: str, + sans_dns: Optional[FrozenSet[str]] = frozenset(), + sans_ip: Optional[FrozenSet[str]] = frozenset(), + sans_oid: Optional[FrozenSet[str]] = frozenset(), + organization: Optional[str] = None, + organizational_unit: Optional[str] = None, + email_address: Optional[str] = None, + country_name: Optional[str] = None, + locality_name: Optional[str] = None, + state_or_province_name: Optional[str] = None, + add_unique_id_to_subject_name: bool = True, +) -> CertificateSigningRequest: + """Generate a CSR using private key and subject. + + Args: + private_key (PrivateKey): Private key + common_name (str): Common name + sans_dns (FrozenSet[str]): DNS Subject Alternative Names + sans_ip (FrozenSet[str]): IP Subject Alternative Names + sans_oid (FrozenSet[str]): OID Subject Alternative Names + organization (Optional[str]): Organization name + organizational_unit (Optional[str]): Organizational unit name + email_address (Optional[str]): Email address + country_name (Optional[str]): Country name + state_or_province_name (Optional[str]): State or province name + locality_name (Optional[str]): Locality name + add_unique_id_to_subject_name (bool): Whether a unique ID must be added to the CSR's + subject name. Always leave to "True" when the CSR is used to request certificates + using the tls-certificates relation. + + Returns: + CertificateSigningRequest: CSR + """ + warnings.warn( + "generate_csr() is deprecated. Use CertificateRequestAttributes.generate_csr() or CertificateSigningRequest.generate() instead.", + DeprecationWarning, + ) + return CertificateRequestAttributes( + common_name=common_name, + sans_dns=sans_dns, + sans_ip=sans_ip, + sans_oid=sans_oid, + organization=organization, + organizational_unit=organizational_unit, + email_address=email_address, + country_name=country_name, + state_or_province_name=state_or_province_name, + locality_name=locality_name, + add_unique_id_to_subject_name=add_unique_id_to_subject_name, + ).generate_csr(private_key=private_key) + + +def generate_ca( + private_key: PrivateKey, + validity: timedelta, + common_name: str, + sans_dns: Optional[FrozenSet[str]] = frozenset(), + sans_ip: Optional[FrozenSet[str]] = frozenset(), + sans_oid: Optional[FrozenSet[str]] = frozenset(), + organization: Optional[str] = None, + organizational_unit: Optional[str] = None, + email_address: Optional[str] = None, + country_name: Optional[str] = None, + state_or_province_name: Optional[str] = None, + locality_name: Optional[str] = None, +) -> Certificate: + """Generate a self signed CA Certificate. + + Args: + private_key: Private key + validity: Certificate validity time + common_name: Common Name that can be an IP or a Full Qualified Domain Name (FQDN). + sans_dns: DNS Subject Alternative Names + sans_ip: IP Subject Alternative Names + sans_oid: OID Subject Alternative Names + organization: Organization name + organizational_unit: Organizational unit name + email_address: Email address + country_name: Certificate Issuing country + state_or_province_name: Certificate Issuing state or province + locality_name: Certificate Issuing locality + + Returns: + CA Certificate. + """ + warnings.warn( + "generate_ca() is deprecated. Use Certificate.generate_self_signed_ca() instead.", + DeprecationWarning, + ) + attributes = CertificateRequestAttributes( + common_name=common_name, + sans_dns=sans_dns, + sans_ip=sans_ip, + sans_oid=sans_oid, + organization=organization, + organizational_unit=organizational_unit, + email_address=email_address, + country_name=country_name, + state_or_province_name=state_or_province_name, + locality_name=locality_name, + is_ca=True, + ) + return Certificate.generate_self_signed_ca(attributes, private_key, validity) + + +def _san_extension( + email_address: Optional[str] = None, + sans_dns: Optional[Collection[str]] = frozenset(), + sans_ip: Optional[Collection[str]] = frozenset(), + sans_oid: Optional[Collection[str]] = frozenset(), +) -> Optional[x509.SubjectAlternativeName]: + sans: List[x509.GeneralName] = [] + if email_address: + # If an e-mail address was provided, it should always be in the SAN + sans.append(x509.RFC822Name(email_address)) + if sans_dns: + sans.extend([x509.DNSName(san) for san in sans_dns]) + if sans_ip: + sans.extend([x509.IPAddress(ipaddress.ip_address(san)) for san in sans_ip]) + if sans_oid: + sans.extend([x509.RegisteredID(x509.ObjectIdentifier(san)) for san in sans_oid]) + if not sans: + return None + return x509.SubjectAlternativeName(sans) + + +def generate_certificate( + csr: CertificateSigningRequest, + ca: Certificate, + ca_private_key: PrivateKey, + validity: timedelta, + is_ca: bool = False, +) -> Certificate: + """Generate a TLS certificate based on a CSR. + + Args: + csr (CertificateSigningRequest): CSR + ca (Certificate): CA Certificate + ca_private_key (PrivateKey): CA private key + validity (timedelta): Certificate validity time + is_ca (bool): Whether the certificate is a CA certificate + + Returns: + Certificate: Certificate + """ + warnings.warn( + "generate_certificate() is deprecated. Use Certificate.generate() instead.", + DeprecationWarning, + ) + return Certificate.generate( + csr=csr, + ca=ca, + ca_private_key=ca_private_key, + validity=validity, + is_ca=is_ca, + ) + + +def _extract_subject_name_attributes( + attributes: CertificateRequestAttributes, +) -> Optional[x509.Name]: + subject_name_attributes = [] + if attributes.common_name: + subject_name_attributes.append( + x509.NameAttribute(x509.NameOID.COMMON_NAME, attributes.common_name) + ) + if attributes.add_unique_id_to_subject_name: + unique_identifier = uuid.uuid4() + subject_name_attributes.append( + x509.NameAttribute(x509.NameOID.X500_UNIQUE_IDENTIFIER, str(unique_identifier)) + ) + if attributes.organization: + subject_name_attributes.append( + x509.NameAttribute(x509.NameOID.ORGANIZATION_NAME, attributes.organization) + ) + if attributes.organizational_unit: + subject_name_attributes.append( + x509.NameAttribute( + x509.NameOID.ORGANIZATIONAL_UNIT_NAME, + attributes.organizational_unit, + ) + ) + if attributes.email_address: + subject_name_attributes.append( + x509.NameAttribute(x509.NameOID.EMAIL_ADDRESS, attributes.email_address) + ) + if attributes.country_name: + subject_name_attributes.append( + x509.NameAttribute(x509.NameOID.COUNTRY_NAME, attributes.country_name) + ) + if attributes.state_or_province_name: + subject_name_attributes.append( + x509.NameAttribute( + x509.NameOID.STATE_OR_PROVINCE_NAME, + attributes.state_or_province_name, + ) + ) + if attributes.locality_name: + subject_name_attributes.append( + x509.NameAttribute(x509.NameOID.LOCALITY_NAME, attributes.locality_name) + ) + + if subject_name_attributes: + return x509.Name(subject_name_attributes) + + return None + + +def _generate_certificate_request_extensions( + authority_key_identifier: bytes, + csr: x509.CertificateSigningRequest, + is_ca: bool, +) -> List[x509.Extension]: + """Generate a list of certificate extensions from a CSR and other known information. + + Args: + authority_key_identifier (bytes): Authority key identifier + csr (x509.CertificateSigningRequest): CSR + is_ca (bool): Whether the certificate is a CA certificate + + Returns: + List[x509.Extension]: List of extensions + """ + cert_extensions_list: List[x509.Extension] = [ + x509.Extension( + oid=ExtensionOID.AUTHORITY_KEY_IDENTIFIER, + value=x509.AuthorityKeyIdentifier( + key_identifier=authority_key_identifier, + authority_cert_issuer=None, + authority_cert_serial_number=None, + ), + critical=False, + ), + x509.Extension( + oid=ExtensionOID.SUBJECT_KEY_IDENTIFIER, + value=x509.SubjectKeyIdentifier.from_public_key(csr.public_key()), + critical=False, + ), + x509.Extension( + oid=ExtensionOID.BASIC_CONSTRAINTS, + critical=True, + value=x509.BasicConstraints(ca=is_ca, path_length=None), + ), + ] + if sans := _generate_subject_alternative_name_extension(csr): + cert_extensions_list.append(sans) + + if is_ca: + cert_extensions_list.append( + x509.Extension( + ExtensionOID.KEY_USAGE, + critical=True, + value=x509.KeyUsage( + digital_signature=False, + content_commitment=False, + key_encipherment=False, + data_encipherment=False, + key_agreement=False, + key_cert_sign=True, + crl_sign=True, + encipher_only=False, + decipher_only=False, + ), + ) + ) + + existing_oids = {ext.oid for ext in cert_extensions_list} + for extension in csr.extensions: + if extension.oid == ExtensionOID.SUBJECT_ALTERNATIVE_NAME: + continue + if extension.oid in existing_oids: + logger.warning("Extension %s is managed by the TLS provider, ignoring.", extension.oid) + continue + cert_extensions_list.append(extension) + + return cert_extensions_list + + +def _generate_subject_alternative_name_extension( + csr: x509.CertificateSigningRequest, +) -> Optional[x509.Extension]: + sans: List[x509.GeneralName] = [] + try: + loaded_san_ext = csr.extensions.get_extension_for_class(x509.SubjectAlternativeName) + sans.extend( + [x509.DNSName(name) for name in loaded_san_ext.value.get_values_for_type(x509.DNSName)] + ) + sans.extend( + [x509.IPAddress(ip) for ip in loaded_san_ext.value.get_values_for_type(x509.IPAddress)] + ) + sans.extend( + [ + x509.RegisteredID(oid) + for oid in loaded_san_ext.value.get_values_for_type(x509.RegisteredID) + ] + ) + sans.extend( + [ + x509.RFC822Name(name) + for name in loaded_san_ext.value.get_values_for_type(x509.RFC822Name) + ] + ) + except x509.ExtensionNotFound: + pass + # If email is present in the CSR Subject, make sure it is also in the SANS + # to conform to RFC 5280. + email = csr.subject.get_attributes_for_oid(NameOID.EMAIL_ADDRESS) + if email: + email_rfc822 = x509.RFC822Name(str(email[0].value)) + if email_rfc822 not in sans: + sans.append(email_rfc822) + + return ( + x509.Extension( + oid=ExtensionOID.SUBJECT_ALTERNATIVE_NAME, + critical=False, + value=x509.SubjectAlternativeName(sans), + ) + if sans + else None + ) + + +class CertificatesRequirerCharmEvents(CharmEvents): + """List of events that the TLS Certificates requirer charm can leverage.""" + + certificate_available = EventSource(CertificateAvailableEvent) + + +class TLSCertificatesRequiresV4(Object): + """A class to manage the TLS certificates interface for a unit or app.""" + + on = CertificatesRequirerCharmEvents() # type: ignore[reportAssignmentType] + + def __init__( + self, + charm: CharmBase, + relationship_name: str, + certificate_requests: List[CertificateRequestAttributes], + mode: Mode = Mode.UNIT, + refresh_events: List[BoundEvent] = [], + private_key: Optional[PrivateKey] = None, + renewal_relative_time: float = 0.9, + ): + """Create a new instance of the TLSCertificatesRequiresV4 class. + + Args: + charm (CharmBase): The charm instance to relate to. + relationship_name (str): The name of the relation that provides the certificates. + certificate_requests (List[CertificateRequestAttributes]): + A list with the attributes of the certificate requests. + mode (Mode): Whether to use unit or app certificates mode. Default is Mode.UNIT. + In UNIT mode the requirer will place the csr in the unit relation data. + Each unit will manage its private key, + certificate signing request and certificate. + UNIT mode is for use cases where each unit has its own identity. + If you don't know which mode to use, you likely need UNIT. + In APP mode the leader unit will place the csr in the app relation databag. + APP mode is for use cases where the underlying application needs the certificate + for example using it as an intermediate CA to sign other certificates. + The certificate can only be accessed by the leader unit. + refresh_events (List[BoundEvent]): A list of events to trigger a refresh of + the certificates. + private_key (Optional[PrivateKey]): The private key to use for the certificates. + If provided, it will be used instead of generating a new one. + If the key is not valid an exception will be raised. + Using this parameter is discouraged, + having to pass around private keys manually can be a security concern. + Allowing the library to generate and manage the key is the more secure approach. + renewal_relative_time (float): The time to renew the certificate relative to its + expiry. + Default is 0.9, meaning 90% of the validity period. + The minimum value is 0.5, meaning 50% of the validity period. + If an invalid value is provided, an exception will be raised. + """ + super().__init__(charm, relationship_name) + if not JujuVersion.from_environ().has_secrets: + logger.warning("This version of the TLS library requires Juju secrets (Juju >= 3.0)") + if not self._mode_is_valid(mode): + raise TLSCertificatesError("Invalid mode. Must be Mode.UNIT or Mode.APP") + for certificate_request in certificate_requests: + if not certificate_request.is_valid(): + raise TLSCertificatesError("Invalid certificate request") + self.charm = charm + self.relationship_name = relationship_name + self.certificate_requests = certificate_requests + self.mode = mode + if private_key and not private_key.is_valid(): + raise TLSCertificatesError("Invalid private key") + if renewal_relative_time <= 0.5 or renewal_relative_time > 1.0: + raise TLSCertificatesError( + "Invalid renewal relative time. Must be between 0.5 and 1.0" + ) + self._private_key = private_key + self.renewal_relative_time = renewal_relative_time + self.framework.observe(charm.on[relationship_name].relation_created, self._configure) + self.framework.observe(charm.on[relationship_name].relation_changed, self._configure) + self.framework.observe(charm.on.secret_expired, self._on_secret_expired) + self.framework.observe(charm.on.secret_remove, self._on_secret_remove) + for event in refresh_events: + self.framework.observe(event, self._configure) + self._security_logger = _OWASPLogger(application=f"tls-certificates-{charm.app.name}") + + def _configure(self, _: Optional[EventBase] = None): + """Handle TLS Certificates Relation Data. + + This method is called during any TLS relation event. + It will generate a private key if it doesn't exist yet. + It will send certificate requests if they haven't been sent yet. + It will find available certificates and emit events. + """ + if not self._tls_relation_created(): + logger.debug("TLS relation not created yet.") + return + self._ensure_private_key() + self._cleanup_certificate_requests() + self._send_certificate_requests() + self._find_available_certificates() + + def _mode_is_valid(self, mode: Mode) -> bool: + return mode in [Mode.UNIT, Mode.APP] + + def _validate_secret_exists(self, secret: Secret) -> None: + secret.get_info() # Will raise `SecretNotFoundError` if the secret does not exist + + def _on_secret_remove(self, event: SecretRemoveEvent) -> None: + """Handle Secret Removed Event.""" + try: + # Ensure the secret exists before trying to remove it, otherwise + # the unit could be stuck in an error state. See the docstring of + # `remove_revision` and the below issue for more information. + # https://github.com/juju/juju/issues/19036 + self._validate_secret_exists(event.secret) + event.secret.remove_revision(event.revision) + except SecretNotFoundError: + logger.warning( + "No such secret %s, nothing to remove", + event.secret.label or event.secret.id, + ) + return + + def _on_secret_expired(self, event: SecretExpiredEvent) -> None: + """Handle Secret Expired Event. + + Renews certificate requests and removes the expired secret. + """ + if not event.secret.label or not event.secret.label.startswith(f"{LIBID}-certificate"): + return + try: + csr_str = event.secret.get_content(refresh=True)["csr"] + except ModelError: + logger.error("Failed to get CSR from secret - Skipping") + return + csr = CertificateSigningRequest.from_string(csr_str) + self._renew_certificate_request(csr) + event.secret.remove_all_revisions() + + def sync(self) -> None: + """Sync TLS Certificates Relation Data. + + This method allows the requirer to sync the TLS certificates relation data + without waiting for the refresh events to be triggered. + """ + self._configure() + + def renew_certificate(self, certificate: ProviderCertificate) -> None: + """Request the renewal of the provided certificate.""" + certificate_signing_request = certificate.certificate_signing_request + secret_label = self._get_csr_secret_label(certificate_signing_request) + try: + secret = self.model.get_secret(label=secret_label) + except SecretNotFoundError: + logger.warning("No matching secret found - Skipping renewal") + return + current_csr = secret.get_content(refresh=True).get("csr", "") + if current_csr != str(certificate_signing_request): + logger.warning("No matching CSR found - Skipping renewal") + return + self._renew_certificate_request(certificate_signing_request) + secret.remove_all_revisions() + + def _renew_certificate_request(self, csr: CertificateSigningRequest): + """Remove existing CSR from relation data and create a new one.""" + self._remove_requirer_csr_from_relation_data(csr) + self._send_certificate_requests() + logger.info("Renewed certificate request") + + def _remove_requirer_csr_from_relation_data(self, csr: CertificateSigningRequest) -> None: + relation = self.model.get_relation(self.relationship_name) + if not relation: + logger.debug("No relation: %s", self.relationship_name) + return + if not self.get_csrs_from_requirer_relation_data(): + logger.info("No CSRs in relation data - Doing nothing") + return + app_or_unit = self._get_app_or_unit() + try: + requirer_relation_data = _RequirerData.load(relation.data[app_or_unit]) + except DataValidationError: + logger.warning("Invalid relation data - Skipping removal of CSR") + return + new_relation_data = copy.deepcopy(requirer_relation_data.certificate_signing_requests) + for requirer_csr in new_relation_data: + if requirer_csr.certificate_signing_request.strip() == str(csr).strip(): + new_relation_data.remove(requirer_csr) + try: + _RequirerData(certificate_signing_requests=new_relation_data).dump( + relation.data[app_or_unit] + ) + logger.info("Removed CSR from relation data") + except ModelError: + logger.warning("Failed to update relation data") + + def _get_app_or_unit(self) -> Union[Application, Unit]: + """Return the unit or app object based on the mode.""" + if self.mode == Mode.UNIT: + return self.model.unit + elif self.mode == Mode.APP: + return self.model.app + raise TLSCertificatesError("Invalid mode") + + @property + def private_key(self) -> Optional[PrivateKey]: + """Return the private key.""" + if self._private_key: + return self._private_key + if not self._private_key_generated(): + return None + secret = self.charm.model.get_secret(label=self._get_private_key_secret_label()) + private_key = secret.get_content(refresh=True)["private-key"] + return PrivateKey.from_string(private_key) + + def _ensure_private_key(self) -> None: + """Make sure there is a private key to be used. + + It will make sure there is a private key passed by the charm using the private_key + parameter or generate a new one otherwise. + """ + # Remove the generated private key + # if one has been passed by the charm using the private_key parameter + if self._private_key: + self._remove_private_key_secret() + return + if self._private_key_generated(): + logger.debug("Private key already generated") + return + self._generate_private_key() + + def regenerate_private_key(self) -> None: + """Regenerate the private key. + + Generate a new private key, remove old certificate requests and send new ones. + + Raises: + TLSCertificatesError: If the private key is passed by the charm using the + private_key parameter. + """ + if self._private_key: + raise TLSCertificatesError( + "Private key is passed by the charm through the private_key parameter, " + "this function can't be used" + ) + if not self._private_key_generated(): + logger.warning("No private key to regenerate") + return + self._generate_private_key() + self._cleanup_certificate_requests() + self._send_certificate_requests() + + def _generate_private_key(self) -> None: + """Generate a new private key and store it in a secret. + + This is the case when the private key used is generated by the library. + and not passed by the charm using the private_key parameter. + """ + self._store_private_key_in_secret(generate_private_key()) + logger.info("Private key generated") + + def _private_key_generated(self) -> bool: + """Check if a private key is stored in a secret. + + This is the case when the private key used is generated by the library. + This should not exist when the private key used + is passed by the charm using the private_key parameter. + """ + try: + secret = self.charm.model.get_secret(label=self._get_private_key_secret_label()) + secret.get_content(refresh=True) + return True + except SecretNotFoundError: + return False + + def _store_private_key_in_secret(self, private_key: PrivateKey) -> None: + try: + secret = self.charm.model.get_secret(label=self._get_private_key_secret_label()) + secret.set_content({"private-key": str(private_key)}) + secret.get_content(refresh=True) + except SecretNotFoundError: + self.charm.unit.add_secret( + content={"private-key": str(private_key)}, + label=self._get_private_key_secret_label(), + ) + + def _remove_private_key_secret(self) -> None: + """Remove the private key secret.""" + try: + secret = self.charm.model.get_secret(label=self._get_private_key_secret_label()) + secret.remove_all_revisions() + except SecretNotFoundError: + logger.warning("Private key secret not found, nothing to remove") + + def _csr_matches_certificate_request( + self, certificate_signing_request: CertificateSigningRequest, is_ca: bool + ) -> bool: + for certificate_request in self.certificate_requests: + if certificate_request == CertificateRequestAttributes.from_csr( + certificate_signing_request, + is_ca, + ): + return True + return False + + def _certificate_requested(self, certificate_request: CertificateRequestAttributes) -> bool: + if not self.private_key: + return False + csr = self._certificate_requested_for_attributes(certificate_request) + if not csr: + return False + if not csr.certificate_signing_request.matches_private_key(key=self.private_key): + return False + return True + + def _certificate_requested_for_attributes( + self, + certificate_request: CertificateRequestAttributes, + ) -> Optional[RequirerCertificateRequest]: + for requirer_csr in self.get_csrs_from_requirer_relation_data(): + if certificate_request == CertificateRequestAttributes.from_csr( + requirer_csr.certificate_signing_request, + requirer_csr.is_ca, + ): + return requirer_csr + return None + + def get_csrs_from_requirer_relation_data(self) -> List[RequirerCertificateRequest]: + """Return list of requirer's CSRs from relation data.""" + if self.mode == Mode.APP and not self.model.unit.is_leader(): + logger.debug("Not a leader unit - Skipping") + return [] + relation = self.model.get_relation(self.relationship_name) + if not relation: + logger.debug("No relation: %s", self.relationship_name) + return [] + app_or_unit = self._get_app_or_unit() + try: + requirer_relation_data = _RequirerData.load(relation.data[app_or_unit]) + except DataValidationError: + logger.warning("Invalid relation data") + return [] + requirer_csrs = [] + for csr in requirer_relation_data.certificate_signing_requests: + requirer_csrs.append( + RequirerCertificateRequest( + relation_id=relation.id, + certificate_signing_request=CertificateSigningRequest.from_string( + csr.certificate_signing_request + ), + is_ca=csr.ca if csr.ca else False, + ) + ) + return requirer_csrs + + def get_provider_certificates(self) -> List[ProviderCertificate]: + """Return list of certificates from the provider's relation data.""" + return self._load_provider_certificates() + + def _load_provider_certificates(self) -> List[ProviderCertificate]: + relation = self.model.get_relation(self.relationship_name) + if not relation: + logger.debug("No relation: %s", self.relationship_name) + return [] + if not relation.app: + logger.debug("No remote app in relation: %s", self.relationship_name) + return [] + try: + provider_relation_data = _ProviderApplicationData.load(relation.data[relation.app]) + except DataValidationError: + logger.warning("Invalid relation data") + return [] + return [ + certificate.to_provider_certificate(relation_id=relation.id) + for certificate in provider_relation_data.certificates + ] + + def _request_certificate(self, csr: CertificateSigningRequest, is_ca: bool) -> None: + """Add CSR to relation data.""" + if self.mode == Mode.APP and not self.model.unit.is_leader(): + logger.debug("Not a leader unit - Skipping") + return + relation = self.model.get_relation(self.relationship_name) + if not relation: + logger.debug("No relation: %s", self.relationship_name) + return + new_csr = _CertificateSigningRequest( + certificate_signing_request=str(csr).strip(), ca=is_ca + ) + app_or_unit = self._get_app_or_unit() + try: + requirer_relation_data = _RequirerData.load(relation.data[app_or_unit]) + except DataValidationError: + requirer_relation_data = _RequirerData( + certificate_signing_requests=[], + ) + new_relation_data = copy.deepcopy(requirer_relation_data.certificate_signing_requests) + new_relation_data.append(new_csr) + try: + _RequirerData(certificate_signing_requests=new_relation_data).dump( + relation.data[app_or_unit] + ) + logger.info("Certificate signing request added to relation data.") + except ModelError: + logger.warning("Failed to update relation data") + + def _send_certificate_requests(self): + if not self.private_key: + logger.debug("Private key not generated yet.") + return + for certificate_request in self.certificate_requests: + if not self._certificate_requested(certificate_request): + csr = certificate_request.generate_csr( + private_key=self.private_key, + ) + if not csr: + logger.warning("Failed to generate CSR") + continue + self._request_certificate(csr=csr, is_ca=certificate_request.is_ca) + + def get_assigned_certificate( + self, certificate_request: CertificateRequestAttributes + ) -> Tuple[Optional[ProviderCertificate], Optional[PrivateKey]]: + """Get the certificate that was assigned to the given certificate request.""" + for requirer_csr in self.get_csrs_from_requirer_relation_data(): + if certificate_request == CertificateRequestAttributes.from_csr( + requirer_csr.certificate_signing_request, + requirer_csr.is_ca, + ): + return self._find_certificate_in_relation_data(requirer_csr), self.private_key + return None, None + + def get_assigned_certificates( + self, + ) -> Tuple[List[ProviderCertificate], Optional[PrivateKey]]: + """Get a list of certificates that were assigned to this or app.""" + assigned_certificates = [] + for requirer_csr in self.get_csrs_from_requirer_relation_data(): + if cert := self._find_certificate_in_relation_data(requirer_csr): + assigned_certificates.append(cert) + return assigned_certificates, self.private_key + + def _find_certificate_in_relation_data( + self, csr: RequirerCertificateRequest + ) -> Optional[ProviderCertificate]: + """Return the certificate that matches the given CSR, validated against the private key.""" + if not self.private_key: + return None + for provider_certificate in self.get_provider_certificates(): + if provider_certificate.certificate_signing_request == csr.certificate_signing_request: + if provider_certificate.certificate.is_ca and not csr.is_ca: + logger.warning("Non CA certificate requested, got a CA certificate, ignoring") + continue + elif not provider_certificate.certificate.is_ca and csr.is_ca: + logger.warning("CA certificate requested, got a non CA certificate, ignoring") + continue + if not provider_certificate.certificate.matches_private_key(self.private_key): + logger.warning( + "Certificate does not match the private key. Ignoring invalid certificate." + ) + continue + return provider_certificate + return None + + def _find_available_certificates(self): + """Find available certificates and emit events. + + This method will find certificates that are available for the requirer's CSRs. + If a certificate is found, it will be set as a secret and an event will be emitted. + If a certificate is revoked, the secret will be removed and an event will be emitted. + """ + requirer_csrs = self.get_csrs_from_requirer_relation_data() + csrs = [csr.certificate_signing_request for csr in requirer_csrs] + provider_certificates = self.get_provider_certificates() + for provider_certificate in provider_certificates: + if provider_certificate.certificate_signing_request in csrs: + secret_label = self._get_csr_secret_label( + provider_certificate.certificate_signing_request + ) + if provider_certificate.revoked: + with suppress(SecretNotFoundError): + logger.debug( + "Removing secret with label %s", + secret_label, + ) + secret = self.model.get_secret(label=secret_label) + secret.remove_all_revisions() + else: + if not self._csr_matches_certificate_request( + certificate_signing_request=provider_certificate.certificate_signing_request, + is_ca=provider_certificate.certificate.is_ca, + ): + logger.debug("Certificate requested for different attributes - Skipping") + continue + try: + secret = self.model.get_secret(label=secret_label) + logger.debug("Setting secret with label %s", secret_label) + # Juju < 3.6 will create a new revision even if the content is the same + if secret.get_content(refresh=True).get("certificate", "") == str( + provider_certificate.certificate + ): + logger.debug( + "Secret %s with correct certificate already exists", secret_label + ) + continue + secret.set_content( + content={ + "certificate": str(provider_certificate.certificate), + "csr": str(provider_certificate.certificate_signing_request), + } + ) + secret.set_info( + expire=calculate_relative_datetime( + target_time=provider_certificate.certificate.expiry_time, + fraction=self.renewal_relative_time, + ), + ) + secret.get_content(refresh=True) + except SecretNotFoundError: + logger.debug("Creating new secret with label %s", secret_label) + secret = self.charm.unit.add_secret( + content={ + "certificate": str(provider_certificate.certificate), + "csr": str(provider_certificate.certificate_signing_request), + }, + label=secret_label, + expire=calculate_relative_datetime( + target_time=provider_certificate.certificate.expiry_time, + fraction=self.renewal_relative_time, + ), + ) + self.on.certificate_available.emit( + certificate_signing_request=provider_certificate.certificate_signing_request, + certificate=provider_certificate.certificate, + ca=provider_certificate.ca, + chain=provider_certificate.chain, + ) + + def _cleanup_certificate_requests(self): + """Clean up certificate requests. + + Remove any certificate requests that falls into one of the following categories: + - The CSR attributes do not match any of the certificate requests defined in + the charm's certificate_requests attribute. + - The CSR public key does not match the private key. + """ + for requirer_csr in self.get_csrs_from_requirer_relation_data(): + if not self._csr_matches_certificate_request( + certificate_signing_request=requirer_csr.certificate_signing_request, + is_ca=requirer_csr.is_ca, + ): + self._remove_requirer_csr_from_relation_data( + requirer_csr.certificate_signing_request + ) + logger.info( + "Removed CSR from relation data because it did not match any certificate request" # noqa: E501 + ) + elif ( + self.private_key + and not requirer_csr.certificate_signing_request.matches_private_key( + self.private_key + ) + ): + self._remove_requirer_csr_from_relation_data( + requirer_csr.certificate_signing_request + ) + logger.info( + "Removed CSR from relation data because it did not match the private key" # noqa: E501 + ) + + def _tls_relation_created(self) -> bool: + relation = self.model.get_relation(self.relationship_name) + if not relation: + return False + return True + + def _get_private_key_secret_label(self) -> str: + if self.mode == Mode.UNIT: + return f"{LIBID}-private-key-{self._get_unit_number()}-{self.relationship_name}" + elif self.mode == Mode.APP: + return f"{LIBID}-private-key-{self.relationship_name}" + else: + raise TLSCertificatesError("Invalid mode. Must be Mode.UNIT or Mode.APP.") + + def _get_csr_secret_label(self, csr: CertificateSigningRequest) -> str: + csr_in_sha256_hex = csr.get_sha256_hex() + if self.mode == Mode.UNIT: + return f"{LIBID}-certificate-{self._get_unit_number()}-{csr_in_sha256_hex}" + elif self.mode == Mode.APP: + return f"{LIBID}-certificate-{csr_in_sha256_hex}" + else: + raise TLSCertificatesError("Invalid mode. Must be Mode.UNIT or Mode.APP.") + + def _get_unit_number(self) -> str: + return self.model.unit.name.split("/")[1] + + +class TLSCertificatesProvidesV4(Object): + """TLS certificates provider class to be instantiated by TLS certificates providers.""" + + def __init__(self, charm: CharmBase, relationship_name: str): + super().__init__(charm, relationship_name) + self.framework.observe(charm.on[relationship_name].relation_joined, self._configure) + self.framework.observe(charm.on[relationship_name].relation_changed, self._configure) + self.framework.observe(charm.on.update_status, self._configure) + self.charm = charm + self.relationship_name = relationship_name + self._security_logger = _OWASPLogger(application=f"tls-certificates-{charm.app.name}") + + def _configure(self, _: EventBase) -> None: + """Handle update status and tls relation changed events. + + This is a common hook triggered on a regular basis. + + Revoke certificates for which no csr exists + """ + if not self.model.unit.is_leader(): + return + self._remove_certificates_for_which_no_csr_exists() + + def _remove_certificates_for_which_no_csr_exists(self) -> None: + provider_certificates = self.get_provider_certificates() + requirer_csrs = [ + request.certificate_signing_request for request in self.get_certificate_requests() + ] + for provider_certificate in provider_certificates: + if provider_certificate.certificate_signing_request not in requirer_csrs: + tls_relation = self._get_tls_relations( + relation_id=provider_certificate.relation_id + ) + self._remove_provider_certificate( + certificate=provider_certificate.certificate, + relation=tls_relation[0], + ) + + def _get_tls_relations(self, relation_id: Optional[int] = None) -> List[Relation]: + return ( + [ + relation + for relation in self.model.relations[self.relationship_name] + if relation.id == relation_id + ] + if relation_id is not None + else self.model.relations.get(self.relationship_name, []) + ) + + def get_certificate_requests( + self, relation_id: Optional[int] = None + ) -> List[RequirerCertificateRequest]: + """Load certificate requests from the relation data.""" + relations = self._get_tls_relations(relation_id) + requirer_csrs: List[RequirerCertificateRequest] = [] + for relation in relations: + for unit in relation.units: + requirer_csrs.extend(self._load_requirer_databag(relation, unit)) + requirer_csrs.extend(self._load_requirer_databag(relation, relation.app)) + return requirer_csrs + + def _load_requirer_databag( + self, relation: Relation, unit_or_app: Union[Application, Unit] + ) -> List[RequirerCertificateRequest]: + try: + requirer_relation_data = _RequirerData.load(relation.data.get(unit_or_app, {})) + except DataValidationError: + logger.debug("Invalid requirer relation data for %s", unit_or_app.name) + return [] + return [ + RequirerCertificateRequest( + relation_id=relation.id, + certificate_signing_request=CertificateSigningRequest.from_string( + csr.certificate_signing_request + ), + is_ca=csr.ca if csr.ca else False, + ) + for csr in requirer_relation_data.certificate_signing_requests + ] + + def _add_provider_certificate( + self, + relation: Relation, + provider_certificate: ProviderCertificate, + ) -> None: + chain = [str(certificate) for certificate in provider_certificate.chain] + if chain[0] != str(provider_certificate.certificate): + logger.warning( + "The order of the chain from the TLS Certificates Provider is incorrect. " + "The leaf certificate should be the first element of the chain." + ) + elif not chain_has_valid_order(chain): + logger.warning( + "The order of the chain from the TLS Certificates Provider is partially incorrect." + ) + new_certificate = _Certificate( + certificate=str(provider_certificate.certificate), + certificate_signing_request=str(provider_certificate.certificate_signing_request), + ca=str(provider_certificate.ca), + chain=chain, + ) + provider_certificates = self._load_provider_certificates(relation) + if new_certificate in provider_certificates: + logger.info("Certificate already in relation data - Doing nothing") + return + provider_certificates.append(new_certificate) + self._dump_provider_certificates(relation=relation, certificates=provider_certificates) + + def _load_provider_certificates(self, relation: Relation) -> List[_Certificate]: + try: + provider_relation_data = _ProviderApplicationData.load(relation.data[self.charm.app]) + except DataValidationError: + logger.debug("Invalid provider relation data") + return [] + return copy.deepcopy(provider_relation_data.certificates) + + def _dump_provider_certificates(self, relation: Relation, certificates: List[_Certificate]): + try: + _ProviderApplicationData(certificates=certificates).dump(relation.data[self.model.app]) + logger.info("Certificate relation data updated") + except ModelError: + logger.warning("Failed to update relation data") + + def _remove_provider_certificate( + self, + relation: Relation, + certificate: Optional[Certificate] = None, + certificate_signing_request: Optional[CertificateSigningRequest] = None, + ) -> None: + """Remove certificate based on certificate or certificate signing request.""" + provider_certificates = self._load_provider_certificates(relation) + for provider_certificate in provider_certificates: + if certificate and provider_certificate.certificate == str(certificate): + provider_certificates.remove(provider_certificate) + if ( + certificate_signing_request + and provider_certificate.certificate_signing_request + == str(certificate_signing_request) + ): + provider_certificates.remove(provider_certificate) + self._dump_provider_certificates(relation=relation, certificates=provider_certificates) + + def revoke_all_certificates(self) -> None: + """Revoke all certificates of this provider. + + This method is meant to be used when the Root CA has changed. + """ + if not self.model.unit.is_leader(): + logger.warning("Unit is not a leader - will not set relation data") + return + relations = self._get_tls_relations() + for relation in relations: + provider_certificates = self._load_provider_certificates(relation) + for certificate in provider_certificates: + certificate.revoked = True + self._dump_provider_certificates(relation=relation, certificates=provider_certificates) + self._security_logger.log_event( + event="all_certificates_revoked", + level=logging.WARNING, + description="All certificates revoked", + ) + + def set_relation_certificate( + self, + provider_certificate: ProviderCertificate, + ) -> None: + """Add certificates to relation data. + + Args: + provider_certificate (ProviderCertificate): ProviderCertificate object + + Returns: + None + """ + if not self.model.unit.is_leader(): + logger.warning("Unit is not a leader - will not set relation data") + return + certificates_relation = self.model.get_relation( + relation_name=self.relationship_name, relation_id=provider_certificate.relation_id + ) + if not certificates_relation: + raise TLSCertificatesError(f"Relation {self.relationship_name} does not exist") + self._remove_provider_certificate( + relation=certificates_relation, + certificate_signing_request=provider_certificate.certificate_signing_request, + ) + self._add_provider_certificate( + relation=certificates_relation, + provider_certificate=provider_certificate, + ) + self._security_logger.log_event( + event="certificate_provided", + level=logging.INFO, + description="Certificate provided to requirer", + relation_id=str(provider_certificate.relation_id), + common_name=provider_certificate.certificate.common_name, + ) + + def get_issued_certificates( + self, relation_id: Optional[int] = None + ) -> List[ProviderCertificate]: + """Return a List of issued (non revoked) certificates. + + Returns: + List: List of ProviderCertificate objects + """ + if not self.model.unit.is_leader(): + logger.warning("Unit is not a leader - will not read relation data") + return [] + provider_certificates = self.get_provider_certificates(relation_id=relation_id) + return [certificate for certificate in provider_certificates if not certificate.revoked] + + def get_provider_certificates( + self, relation_id: Optional[int] = None + ) -> List[ProviderCertificate]: + """Return a List of issued certificates.""" + certificates: List[ProviderCertificate] = [] + relations = self._get_tls_relations(relation_id) + for relation in relations: + if not relation.app: + logger.warning("Relation %s does not have an application", relation.id) + continue + for certificate in self._load_provider_certificates(relation): + certificates.append(certificate.to_provider_certificate(relation_id=relation.id)) + return certificates + + def get_unsolicited_certificates( + self, relation_id: Optional[int] = None + ) -> List[ProviderCertificate]: + """Return provider certificates for which no certificate requests exists. + + Those certificates should be revoked. + """ + unsolicited_certificates: List[ProviderCertificate] = [] + provider_certificates = self.get_provider_certificates(relation_id=relation_id) + requirer_csrs = self.get_certificate_requests(relation_id=relation_id) + list_of_csrs = [csr.certificate_signing_request for csr in requirer_csrs] + for certificate in provider_certificates: + if certificate.certificate_signing_request not in list_of_csrs: + unsolicited_certificates.append(certificate) + return unsolicited_certificates + + def get_outstanding_certificate_requests( + self, relation_id: Optional[int] = None + ) -> List[RequirerCertificateRequest]: + """Return CSR's for which no certificate has been issued. + + Args: + relation_id (int): Relation id + + Returns: + list: List of RequirerCertificateRequest objects. + """ + requirer_csrs = self.get_certificate_requests(relation_id=relation_id) + outstanding_csrs: List[RequirerCertificateRequest] = [] + for relation_csr in requirer_csrs: + if not self._certificate_issued_for_csr( + csr=relation_csr.certificate_signing_request, + relation_id=relation_id, + ): + outstanding_csrs.append(relation_csr) + return outstanding_csrs + + def _certificate_issued_for_csr( + self, csr: CertificateSigningRequest, relation_id: Optional[int] + ) -> bool: + """Check whether a certificate has been issued for a given CSR.""" + issued_certificates_per_csr = self.get_issued_certificates(relation_id=relation_id) + for issued_certificate in issued_certificates_per_csr: + if issued_certificate.certificate_signing_request == csr: + return csr.matches_certificate(issued_certificate.certificate) + return False diff --git a/dovecot-charm/pyproject.toml b/dovecot-charm/pyproject.toml index 11187c5..ca14524 100644 --- a/dovecot-charm/pyproject.toml +++ b/dovecot-charm/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "charmlibs-apt==1.0.0.post0", "jinja2", "pydantic[email]>=2.12.5", + "cryptography>=46.0.6", "charmlibs-systemd==1.0.0", ] @@ -120,6 +121,13 @@ lint.per-file-ignores."tests/*" = [ "D417", "S", ] +lint.per-file-ignores."src/charms/*" = [ + "B006", + "B028", + "RUF100", + "S101", + "SIM103", +] lint.flake8-copyright.author = "Canonical Ltd." lint.flake8-copyright.min-file-size = 1 lint.flake8-copyright.notice-rgx = "Copyright\\s\\d{4}([-,]\\d{4})*\\s+" diff --git a/dovecot-charm/src/charm.py b/dovecot-charm/src/charm.py index 0c2641c..a6bb3fa 100644 --- a/dovecot-charm/src/charm.py +++ b/dovecot-charm/src/charm.py @@ -18,6 +18,11 @@ from ops.main import main from ops.model import BlockedStatus, MaintenanceStatus +from charmlibs.interfaces.tls_certificates import ( + CertificateAvailableEvent, + CertificateRequestAttributes, + TLSCertificatesRequiresV4, +) from constants import ( DOVECOT_CONF_TARGET, DOVECOT_CONF_TEMPLATE, @@ -62,6 +67,28 @@ def __init__(self, *args): loader=jinja2.FileSystemLoader(TEMPLATES_DIR), autoescape=True ) + # TLS certificates directory + self.tls_cert_dir = Path("/etc/dovecot/private") + + # TLS certificates integration + self._tls = None + mailname = self.config.get("mailname", "") + if mailname: + self._tls = TLSCertificatesRequiresV4( + charm=self, + relationship_name="certificates", + certificate_requests=[ + CertificateRequestAttributes( + common_name=mailname, + sans_dns=frozenset([mailname]), + ) + ], + refresh_events=[self.on.config_changed], + ) + self.framework.observe( + self._tls.on.certificate_available, self._on_certificate_available + ) + def get_units(self) -> typing.List[str]: """Return a list of all units in the application. @@ -248,6 +275,35 @@ def _on_clear_queue_action(self, event): logger.exception(f"Failed to clear Postfix queue: {e.stderr}") event.fail(f"Failed to run postsuper: {e.stderr}") + def _on_certificate_available(self, event: CertificateAvailableEvent): + """Handle TLS certificate available event.""" + mailname = self.config.get("mailname", "") + if not mailname: + logger.warning("Certificate available but mailname not configured") + return + + self.tls_cert_dir.mkdir(parents=True, exist_ok=True) + + cert_path = self.tls_cert_dir / f"{mailname}.pem" + key_path = self.tls_cert_dir / f"{mailname}.key" + + cert_content = str(event.certificate.certificate) + if event.certificate.ca: + cert_content += "\n" + str(event.certificate.ca) + + cert_path.write_text(cert_content) + cert_path.chmod(0o644) + logger.info(f"Certificate written to {cert_path}") + + private_key = self._tls.private_key + if private_key: + key_path.write_text(str(private_key)) + key_path.chmod(0o600) + logger.info(f"Private key written to {key_path}") + + if systemd.service_reload("dovecot"): + logger.info("Dovecot service reloaded with new TLS certificate") + if __name__ == "__main__": # pragma: nocover main(DovecotCharm) diff --git a/dovecot-charm/src/charms/tls_certificates_interface/v4/tls_certificates.py b/dovecot-charm/src/charms/tls_certificates_interface/v4/tls_certificates.py new file mode 100644 index 0000000..b779c7c --- /dev/null +++ b/dovecot-charm/src/charms/tls_certificates_interface/v4/tls_certificates.py @@ -0,0 +1,2525 @@ +# Copyright 2024 Canonical Ltd. +# See LICENSE file for licensing details. + +"""Legacy Charmhub-hosted lib, deprecated in favour of ``charmlibs.interfaces.tls_certificates``. + +WARNING: This library is deprecated. +It will not receive feature updates or bugfixes. +``charmlibs.interfaces.tls_certificates`` 1.0 is a bug-for-bug compatible migration of this library. + +To migrate: +1. Add 'charmlibs-interfaces-tls-certificates~=1.0' to your charm's dependencies, + and remove this Charmhub-hosted library from your charm. +2. You can also remove any dependencies added to your charm only because of this library. +3. Replace `from charms.tls_certificates_interface.v4 import tls_certificates` + with `from charmlibs.interfaces import tls_certificates`. + +Read more: +- https://documentation.ubuntu.com/charmlibs +- https://pypi.org/project/charmlibs-interfaces-tls-certificates + +--- + +Charm library for managing TLS certificates (V4). + +This library contains the Requires and Provides classes for handling the tls-certificates +interface. + +Pre-requisites: + - Juju >= 3.0 + - cryptography >= 43.0.0 + - pydantic >= 1.0 + +Learn more on how-to use the TLS Certificates interface library by reading the documentation: +- https://charmhub.io/tls-certificates-interface/ + +""" # noqa: D214, D405, D411, D416 + +import copy +import ipaddress +import json +import logging +import uuid +import warnings +from contextlib import suppress +from dataclasses import asdict, dataclass, field +from datetime import datetime, timedelta, timezone +from enum import Enum +from typing import ( + Collection, + Dict, + FrozenSet, + List, + MutableMapping, + Optional, + Set, + Tuple, + Union, +) + +import pydantic +from cryptography import x509 +from cryptography.exceptions import InvalidSignature +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.asymmetric.types import CertificateIssuerPrivateKeyTypes +from cryptography.x509.oid import ExtensionOID, NameOID +from ops import BoundEvent, CharmBase, CharmEvents, Secret, SecretExpiredEvent, SecretRemoveEvent +from ops.framework import EventBase, EventSource, Handle, Object +from ops.jujuversion import JujuVersion +from ops.model import Application, ModelError, Relation, SecretNotFoundError, Unit + +# The unique Charmhub library identifier, never change it +LIBID = "afd8c2bccf834997afce12c2706d2ede" + +# Increment this major API version when introducing breaking changes +LIBAPI = 4 + +# Increment this PATCH version before using `charmcraft publish-lib` or reset +# to 0 if you are raising the major API version +LIBPATCH = 27 + +PYDEPS = [ + "cryptography>=43.0.0", + "pydantic", +] +IS_PYDANTIC_V1 = int(pydantic.version.VERSION.split(".")[0]) < 2 + +logger = logging.getLogger(__name__) + +NESTED_JSON_KEY = "owasp_event" + + +@dataclass +class _OWASPLogEvent: + """OWASP-compliant log event.""" + + datetime: str + event: str + level: str + description: str + type: str = "security" + labels: Dict[str, str] = field(default_factory=dict) + + def to_json(self) -> str: + return json.dumps(self.to_dict(), ensure_ascii=False) + + def to_dict(self) -> Dict: + log_event = dict(asdict(self), **self.labels) + log_event.pop("labels", None) + return {k: v for k, v in log_event.items() if v is not None} + + +class _OWASPLogger: + """OWASP-compliant logger for security events.""" + + def __init__(self, application: Optional[str] = None): + self.application = application + self._logger = logging.getLogger(__name__) + + def log_event(self, event: str, level: int, description: str, **labels: str): + if self.application and "application" not in labels: + labels["application"] = self.application + log = _OWASPLogEvent( + datetime=datetime.now(timezone.utc).astimezone().isoformat(), + event=event, + level=logging.getLevelName(level), + description=description, + labels=labels, + ) + self._logger.log(level, log.to_json(), extra={NESTED_JSON_KEY: log.to_dict()}) + + +class TLSCertificatesError(Exception): + """Base class for custom errors raised by this library.""" + + +class DataValidationError(TLSCertificatesError): + """Raised when data validation fails.""" + + +class _DatabagModel(pydantic.BaseModel): + """Base databag model. + + Supports both pydantic v1 and v2. + """ + + if IS_PYDANTIC_V1: + + class Config: + """Pydantic config.""" + + # ignore any extra fields in the databag + extra = "ignore" + """Ignore any extra fields in the databag.""" + allow_population_by_field_name = True + """Allow instantiating this class by field name (instead of forcing alias).""" + + _NEST_UNDER = None + + model_config = pydantic.ConfigDict( + # tolerate additional keys in databag + extra="ignore", + # Allow instantiating this class by field name (instead of forcing alias). + populate_by_name=True, + # Custom config key: whether to nest the whole datastructure (as json) + # under a field or spread it out at the toplevel. + _NEST_UNDER=None, + ) # type: ignore + """Pydantic config.""" + + @classmethod + def load(cls, databag: MutableMapping): + """Load this model from a Juju databag.""" + if IS_PYDANTIC_V1: + return cls._load_v1(databag) + nest_under = cls.model_config.get("_NEST_UNDER") + if nest_under: + return cls.model_validate(json.loads(databag[nest_under])) + + try: + data = { + k: json.loads(v) + for k, v in databag.items() + # Don't attempt to parse model-external values + if k in {(f.alias or n) for n, f in cls.model_fields.items()} + } + except json.JSONDecodeError as e: + msg = f"invalid databag contents: expecting json. {databag}" + logger.error(msg) + raise DataValidationError(msg) from e + + try: + return cls.model_validate_json(json.dumps(data)) + except pydantic.ValidationError as e: + msg = f"failed to validate databag: {databag}" + logger.debug(msg, exc_info=True) + raise DataValidationError(msg) from e + + @classmethod + def _load_v1(cls, databag: MutableMapping): + """Load implementation for pydantic v1.""" + if cls._NEST_UNDER: + return cls.parse_obj(json.loads(databag[cls._NEST_UNDER])) + + try: + data = { + k: json.loads(v) + for k, v in databag.items() + # Don't attempt to parse model-external values + if k in {f.alias for f in cls.__fields__.values()} + } + except json.JSONDecodeError as e: + msg = f"invalid databag contents: expecting json. {databag}" + logger.error(msg) + raise DataValidationError(msg) from e + + try: + return cls.parse_raw(json.dumps(data)) # type: ignore + except pydantic.ValidationError as e: + msg = f"failed to validate databag: {databag}" + logger.debug(msg, exc_info=True) + raise DataValidationError(msg) from e + + def dump(self, databag: Optional[MutableMapping] = None, clear: bool = True): + """Write the contents of this model to Juju databag. + + Args: + databag: The databag to write to. + clear: Whether to clear the databag before writing. + + Returns: + MutableMapping: The databag. + """ + if IS_PYDANTIC_V1: + return self._dump_v1(databag, clear) + if clear and databag: + databag.clear() + + if databag is None: + databag = {} + nest_under = self.model_config.get("_NEST_UNDER") + if nest_under: + databag[nest_under] = self.model_dump_json( + by_alias=True, + # skip keys whose values are default + exclude_defaults=True, + ) + return databag + + dct = self.model_dump(mode="json", by_alias=True, exclude_defaults=True) + databag.update({k: json.dumps(v) for k, v in dct.items()}) + return databag + + def _dump_v1(self, databag: Optional[MutableMapping] = None, clear: bool = True): + """Dump implementation for pydantic v1.""" + if clear and databag: + databag.clear() + + if databag is None: + databag = {} + + if self._NEST_UNDER: + databag[self._NEST_UNDER] = self.json(by_alias=True, exclude_defaults=True) + return databag + + dct = json.loads(self.json(by_alias=True, exclude_defaults=True)) + databag.update({k: json.dumps(v) for k, v in dct.items()}) + + return databag + + +class _Certificate(pydantic.BaseModel): + """Certificate model.""" + + ca: str + certificate_signing_request: str + certificate: str + chain: Optional[List[str]] = None + revoked: Optional[bool] = None + + def to_provider_certificate(self, relation_id: int) -> "ProviderCertificate": + """Convert to a ProviderCertificate.""" + return ProviderCertificate( + relation_id=relation_id, + certificate=Certificate.from_string(self.certificate), + certificate_signing_request=CertificateSigningRequest.from_string( + self.certificate_signing_request + ), + ca=Certificate.from_string(self.ca), + chain=[Certificate.from_string(certificate) for certificate in self.chain] + if self.chain + else [], + revoked=self.revoked, + ) + + +class _CertificateSigningRequest(pydantic.BaseModel): + """Certificate signing request model.""" + + certificate_signing_request: str + ca: Optional[bool] + + +class _ProviderApplicationData(_DatabagModel): + """Provider application data model.""" + + certificates: List[_Certificate] = [] + + +class _RequirerData(_DatabagModel): + """Requirer data model. + + The same model is used for the unit and application data. + """ + + certificate_signing_requests: List[_CertificateSigningRequest] = [] + + +class Mode(Enum): + """Enum representing the mode of the certificate request. + + UNIT (default): Request a certificate for the unit. + Each unit will manage its private key, + certificate signing request and certificate. + APP: Request a certificate for the application. + Only the leader unit will manage the private key, certificate signing request + and certificate. + """ + + UNIT = 1 + APP = 2 + + +class PrivateKey: + """This class represents a private key.""" + + def __init__( + self, raw: Optional[str] = None, x509_object: Optional[rsa.RSAPrivateKey] = None + ) -> None: + """Initialize the PrivateKey object. + + If both raw and x509_object are provided, x509_object takes precedence. + """ + if x509_object: + self._private_key = x509_object + elif raw: + self._private_key = serialization.load_pem_private_key( + raw.encode(), + password=None, + ) + else: + raise ValueError("Either raw private key string or x509_object must be provided") + + @property + def raw(self) -> str: + """Return the PEM-formatted string representation of the private key.""" + return str(self) + + def __str__(self): + """Return the private key as a string in PEM format.""" + return ( + self._private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + .decode() + .strip() + ) + + def __hash__(self): + """Return the hash of the private key.""" + return hash(self.raw) + + @classmethod + def from_string(cls, private_key: str) -> "PrivateKey": + """Create a PrivateKey object from a private key.""" + return cls(raw=private_key) + + def is_valid(self) -> bool: + """Validate that the private key is PEM-formatted, RSA, and at least 2048 bits.""" + try: + if not isinstance(self._private_key, rsa.RSAPrivateKey): + logger.warning("Private key is not an RSA key") + return False + + if self._private_key.key_size < 2048: + logger.warning("RSA key size is less than 2048 bits") + return False + + return True + except ValueError: + logger.warning("Invalid private key format") + return False + + @classmethod + def generate(cls, key_size: int = 2048, public_exponent: int = 65537) -> "PrivateKey": + """Generate a new RSA private key. + + Args: + key_size: The size of the key in bits. + public_exponent: The public exponent of the key. + + Returns: + PrivateKey: The generated private key. + """ + private_key = rsa.generate_private_key( + public_exponent=public_exponent, + key_size=key_size, + ) + _OWASPLogger().log_event( + event="private_key_generated", + level=logging.INFO, + description="Private key generated", + key_size=str(key_size), + ) + return PrivateKey(x509_object=private_key) + + def __eq__(self, other: object) -> bool: + """Check if two PrivateKey objects are equal.""" + if not isinstance(other, PrivateKey): + return NotImplemented + return self.raw == other.raw + + +class Certificate: + """This class represents a certificate.""" + + _cert: x509.Certificate + + def __init__( + self, + raw: Optional[str] = None, # Must remain first argument for backwards compatibility + # Old Interface fields (ignored) + common_name: Optional[str] = None, + expiry_time: Optional[datetime] = None, + validity_start_time: Optional[datetime] = None, + is_ca: Optional[bool] = None, + sans_dns: Optional[Set[str]] = None, + sans_ip: Optional[Set[str]] = None, + sans_oid: Optional[Set[str]] = None, + email_address: Optional[str] = None, + organization: Optional[str] = None, + organizational_unit: Optional[str] = None, + country_name: Optional[str] = None, + state_or_province_name: Optional[str] = None, + locality_name: Optional[str] = None, + # End Old Interface fields + x509_object: Optional[x509.Certificate] = None, + ) -> None: + """Initialize the Certificate object. + + This initializer must maintain the old interface while also allowing + instantiation from an existing x509_object. It ignores all fields + other than raw and x509_object, preferring x509_object. + """ + if x509_object: + self._cert = x509_object + elif raw: + self._cert = x509.load_pem_x509_certificate(data=raw.encode()) + else: + raise ValueError("Either raw certificate string or x509_object must be provided") + + @property + def raw(self) -> str: + """Return the PEM-formatted string representation of the certificate.""" + return str(self) + + @property + def common_name(self) -> str: + """Return the common name of the certificate.""" + # We maintain compatibility with the old interface by returning + # an empty string if no common name is set. + common_name = self._cert.subject.get_attributes_for_oid(NameOID.COMMON_NAME) + return str(common_name[0].value) if common_name else "" + + @property + def expiry_time(self) -> datetime: + """Return the expiry time of the certificate.""" + return self._cert.not_valid_after_utc + + @property + def validity_start_time(self) -> datetime: + """Return the validity start time of the certificate.""" + return self._cert.not_valid_before_utc + + @property + def is_ca(self) -> bool: + """Return whether the certificate is a CA certificate.""" + try: + return self._cert.extensions.get_extension_for_oid( + ExtensionOID.BASIC_CONSTRAINTS + ).value.ca # type: ignore[reportAttributeAccessIssue] + except x509.ExtensionNotFound: + return False + + @property + def sans_dns(self) -> Optional[Set[str]]: + """Return the DNS Subject Alternative Names of the certificate.""" + with suppress(x509.ExtensionNotFound): + sans = self._cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value + return {str(san) for san in sans.get_values_for_type(x509.DNSName)} + return None + + @property + def sans_ip(self) -> Optional[Set[str]]: + """Return the IP Subject Alternative Names of the certificate.""" + with suppress(x509.ExtensionNotFound): + sans = self._cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value + return {str(san) for san in sans.get_values_for_type(x509.IPAddress)} + return None + + @property + def sans_oid(self) -> Optional[Set[str]]: + """Return the OID Subject Alternative Names of the certificate.""" + with suppress(x509.ExtensionNotFound): + sans = self._cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value + return {str(san.dotted_string) for san in sans.get_values_for_type(x509.RegisteredID)} + return None + + @property + def email_address(self) -> Optional[str]: + """Return the email address of the certificate.""" + email_address = self._cert.subject.get_attributes_for_oid(NameOID.EMAIL_ADDRESS) + return str(email_address[0].value) if email_address else None + + @property + def organization(self) -> Optional[str]: + """Return the organization name of the certificate.""" + organization = self._cert.subject.get_attributes_for_oid(NameOID.ORGANIZATION_NAME) + return str(organization[0].value) if organization else None + + @property + def organizational_unit(self) -> Optional[str]: + """Return the organizational unit name of the certificate.""" + organizational_unit = self._cert.subject.get_attributes_for_oid( + NameOID.ORGANIZATIONAL_UNIT_NAME + ) + return str(organizational_unit[0].value) if organizational_unit else None + + @property + def country_name(self) -> Optional[str]: + """Return the country name of the certificate.""" + country_name = self._cert.subject.get_attributes_for_oid(NameOID.COUNTRY_NAME) + return str(country_name[0].value) if country_name else None + + @property + def state_or_province_name(self) -> Optional[str]: + """Return the state or province name of the certificate.""" + state_or_province_name = self._cert.subject.get_attributes_for_oid( + NameOID.STATE_OR_PROVINCE_NAME + ) + return str(state_or_province_name[0].value) if state_or_province_name else None + + @property + def locality_name(self) -> Optional[str]: + """Return the locality name of the certificate.""" + locality_name = self._cert.subject.get_attributes_for_oid(NameOID.LOCALITY_NAME) + return str(locality_name[0].value) if locality_name else None + + def __str__(self) -> str: + """Return the certificate as a string.""" + return self._cert.public_bytes(serialization.Encoding.PEM).decode().strip() + + def __eq__(self, other: object) -> bool: + """Check if two Certificate objects are equal.""" + if not isinstance(other, Certificate): + return NotImplemented + return self.raw == other.raw + + @classmethod + def from_string(cls, certificate: str) -> "Certificate": + """Create a Certificate object from a certificate.""" + try: + certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) + except ValueError as e: + logger.error("Could not load certificate: %s", e) + raise TLSCertificatesError("Could not load certificate") + + return cls(x509_object=certificate_object) + + def matches_private_key(self, private_key: PrivateKey) -> bool: + """Check if this certificate matches a given private key. + + Args: + private_key (PrivateKey): The private key to validate against. + + Returns: + bool: True if the certificate matches the private key, False otherwise. + """ + try: + cert_public_key = self._cert.public_key() + key_public_key = private_key._private_key.public_key() + + if not isinstance(cert_public_key, rsa.RSAPublicKey): + logger.warning("Certificate does not use RSA public key") + return False + + if not isinstance(key_public_key, rsa.RSAPublicKey): + logger.warning("Private key is not an RSA key") + return False + + return cert_public_key.public_numbers() == key_public_key.public_numbers() + except Exception as e: + logger.warning("Failed to validate certificate and private key match: %s", e) + return False + + @classmethod + def generate( + cls, + csr: "CertificateSigningRequest", + ca: "Certificate", + ca_private_key: "PrivateKey", + validity: timedelta, + is_ca: bool = False, + ) -> "Certificate": + """Generate a certificate from a CSR signed by the given CA and CA private key. + + Args: + csr: The certificate signing request. + ca: The CA certificate. + ca_private_key: The CA private key. + validity: The validity period of the certificate. + is_ca: Whether the generated certificate is a CA certificate. + + Returns: + Certificate: The generated certificate. + """ + # Ideally, this would be the constructor, but we can't add new + # required parameters to the constructor without breaking backwards + # compatibility. + private_key = serialization.load_pem_private_key( + str(ca_private_key).encode(), password=None + ) + assert isinstance(private_key, CertificateIssuerPrivateKeyTypes) + + # Create a certificate builder + cert_builder = x509.CertificateBuilder( + subject_name=csr._csr.subject, + # issuer_name=ca._cert.subject, # TODO: Validate this is correct, the old code used `issuer` + issuer_name=ca._cert.issuer, + public_key=csr._csr.public_key(), + serial_number=x509.random_serial_number(), + not_valid_before=datetime.now(timezone.utc), + not_valid_after=datetime.now(timezone.utc) + validity, + ) + extensions = _generate_certificate_request_extensions( + authority_key_identifier=ca._cert.extensions.get_extension_for_class( + x509.SubjectKeyIdentifier + ).value.key_identifier, + csr=csr._csr, + is_ca=is_ca, + ) + for extension in extensions: + try: + cert_builder = cert_builder.add_extension(extension.value, extension.critical) + except ValueError as e: + logger.error("Could not add extension to certificate: %s", e) + raise TLSCertificatesError("Could not add extension to certificate") from e + + # Sign the certificate with the CA's private key + cert = cert_builder.sign(private_key=private_key, algorithm=hashes.SHA256()) + _OWASPLogger().log_event( + event="certificate_generated", + level=logging.INFO, + description="Certificate generated from CSR", + common_name=csr.common_name, + is_ca=str(is_ca), + validity_days=str(validity.days), + ) + + return cls(x509_object=cert) + + @classmethod + def generate_self_signed_ca( + cls, + attributes: "CertificateRequestAttributes", + private_key: PrivateKey, + validity: timedelta, + ) -> "Certificate": + """Generate a self-signed CA certificate. + + Args: + attributes: The certificate request attributes. + private_key: The private key to sign the CA certificate. + validity: The validity period of the CA certificate. + + Returns: + Certificate: The generated CA certificate. + """ + assert isinstance(private_key._private_key, rsa.RSAPrivateKey) + + public_key = private_key._private_key.public_key() + + builder = x509.CertificateBuilder( + public_key=public_key, + serial_number=x509.random_serial_number(), + not_valid_before=datetime.now(timezone.utc), + not_valid_after=datetime.now(timezone.utc) + validity, + ) + + if subject_name := _extract_subject_name_attributes(attributes): + builder = builder.subject_name(subject_name).issuer_name(subject_name) + + builder = ( + builder.add_extension( + x509.SubjectKeyIdentifier.from_public_key(public_key), critical=False + ) + .add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True) + .add_extension( + x509.KeyUsage( + digital_signature=True, + key_encipherment=True, + key_cert_sign=True, + key_agreement=False, + content_commitment=False, + data_encipherment=False, + crl_sign=False, + encipher_only=False, + decipher_only=False, + ), + critical=True, + ) + ) + + if san_extension := _san_extension( + email_address=attributes.email_address, + sans_dns=attributes.sans_dns, + sans_ip=attributes.sans_ip, + sans_oid=attributes.sans_oid, + ): + builder = builder.add_extension(san_extension, critical=False) + + cert = cls(x509_object=builder.sign(private_key._private_key, algorithm=hashes.SHA256())) + + _OWASPLogger().log_event( + event="ca_certificate_generated", + level=logging.INFO, + description="CA certificate generated", + common_name=cert.common_name, + validity_days=str(validity.days), + ) + + return cert + + def __hash__(self): + """Return the hash of the private key.""" + return hash(self.raw) + + +class CertificateSigningRequest: + """A representation of the certificate signing request.""" + + _csr: x509.CertificateSigningRequest + + def __init__( + self, + raw: Optional[str] = None, # Must remain first argument for backwards compatibility + # Old Interface fields (ignored) + common_name: Optional[str] = None, + sans_dns: Optional[Set[str]] = None, + sans_ip: Optional[Set[str]] = None, + sans_oid: Optional[Set[str]] = None, + email_address: Optional[str] = None, + organization: Optional[str] = None, + organizational_unit: Optional[str] = None, + country_name: Optional[str] = None, + state_or_province_name: Optional[str] = None, + locality_name: Optional[str] = None, + has_unique_identifier: Optional[bool] = None, + # End Old Interface fields + x509_object: Optional[x509.CertificateSigningRequest] = None, + ): + """Initialize the CertificateSigningRequest object. + + This initializer must maintain the old interface while also allowing + instantiation from an existing x509_object. It ignores all fields + other than raw and x509_object, preferring x509_object. + """ + if x509_object: + self._csr = x509_object + return + elif raw: + try: + self._csr = x509.load_pem_x509_csr(raw.encode()) + except ValueError as e: + logger.error("Could not load CSR: %s", e) + raise TLSCertificatesError("Could not load CSR") + return + raise ValueError("Either raw CSR string or x509_object must be provided") + + @property + def common_name(self) -> str: + """Return the common name of the CSR.""" + common_name = self._csr.subject.get_attributes_for_oid(NameOID.COMMON_NAME) + return str(common_name[0].value) if common_name else "" + + @property + def sans_dns(self) -> Set[str]: + """Return the DNS Subject Alternative Names of the CSR.""" + with suppress(x509.ExtensionNotFound): + sans = self._csr.extensions.get_extension_for_class(x509.SubjectAlternativeName).value + return {str(san) for san in sans.get_values_for_type(x509.DNSName)} + return set() + + @property + def sans_ip(self) -> Set[str]: + """Return the IP Subject Alternative Names of the CSR.""" + with suppress(x509.ExtensionNotFound): + sans = self._csr.extensions.get_extension_for_class(x509.SubjectAlternativeName).value + return {str(san) for san in sans.get_values_for_type(x509.IPAddress)} + return set() + + @property + def sans_oid(self) -> Set[str]: + """Return the OID Subject Alternative Names of the CSR.""" + with suppress(x509.ExtensionNotFound): + sans = self._csr.extensions.get_extension_for_class(x509.SubjectAlternativeName).value + return {str(san.dotted_string) for san in sans.get_values_for_type(x509.RegisteredID)} + return set() + + @property + def email_address(self) -> Optional[str]: + """Return the email address of the CSR.""" + email_address = self._csr.subject.get_attributes_for_oid(NameOID.EMAIL_ADDRESS) + return str(email_address[0].value) if email_address else None + + @property + def organization(self) -> Optional[str]: + """Return the organization name of the CSR.""" + organization = self._csr.subject.get_attributes_for_oid(NameOID.ORGANIZATION_NAME) + return str(organization[0].value) if organization else None + + @property + def organizational_unit(self) -> Optional[str]: + """Return the organizational unit name of the CSR.""" + organizational_unit = self._csr.subject.get_attributes_for_oid( + NameOID.ORGANIZATIONAL_UNIT_NAME + ) + return str(organizational_unit[0].value) if organizational_unit else None + + @property + def country_name(self) -> Optional[str]: + """Return the country name of the CSR.""" + country_name = self._csr.subject.get_attributes_for_oid(NameOID.COUNTRY_NAME) + return str(country_name[0].value) if country_name else None + + @property + def state_or_province_name(self) -> Optional[str]: + """Return the state or province name of the CSR.""" + state_or_province_name = self._csr.subject.get_attributes_for_oid( + NameOID.STATE_OR_PROVINCE_NAME + ) + return str(state_or_province_name[0].value) if state_or_province_name else None + + @property + def locality_name(self) -> Optional[str]: + """Return the locality name of the CSR.""" + locality_name = self._csr.subject.get_attributes_for_oid(NameOID.LOCALITY_NAME) + return str(locality_name[0].value) if locality_name else None + + @property + def has_unique_identifier(self) -> bool: + """Return whether the CSR has a unique identifier.""" + unique_identifier = self._csr.subject.get_attributes_for_oid( + NameOID.X500_UNIQUE_IDENTIFIER + ) + return bool(unique_identifier) + + @property + def raw(self) -> str: + """Return the PEM-formatted string representation of the CSR.""" + return self.__str__() + + def __str__(self) -> str: + """Return the CSR as a string.""" + return self._csr.public_bytes(serialization.Encoding.PEM).decode().strip() + + @property + def additional_critical_extensions(self) -> List[x509.ExtensionType]: + """Return additional critical extensions present on the CSR (excluding SAN).""" + extensions: List[x509.ExtensionType] = [] + for extension in self._csr.extensions: + if extension.critical and extension.oid != ExtensionOID.SUBJECT_ALTERNATIVE_NAME: + extensions.append(extension.value) + return extensions + + @classmethod + def from_string(cls, csr: str) -> "CertificateSigningRequest": + """Create a CertificateSigningRequest object from a CSR.""" + return cls(raw=csr) + + @classmethod + def from_csr(cls, csr: x509.CertificateSigningRequest) -> "CertificateSigningRequest": + """Create a CertificateSigningRequest object from a CSR.""" + return cls(x509_object=csr) + + def __eq__(self, other: object) -> bool: + """Check if two CertificateSigningRequest objects are equal.""" + if not isinstance(other, CertificateSigningRequest): + return NotImplemented + return self.raw == other.raw + + def __hash__(self): + """Return the hash of the private key.""" + return hash(self.raw) + + def matches_certificate(self, certificate: Certificate) -> bool: + """Check if this CSR matches a given certificate. + + Args: + certificate (Certificate): The certificate to validate against. + + Returns: + bool: True if the CSR matches the certificate, False otherwise. + """ + return self._csr.public_key() == certificate._cert.public_key() + + def matches_private_key(self, key: PrivateKey) -> bool: + """Check if a CSR matches a private key. + + This function only works with RSA keys. + + Args: + key (PrivateKey): Private key + Returns: + bool: True/False depending on whether the CSR matches the private key. + """ + try: + key_object_public_key = key._private_key.public_key() + csr_object_public_key = self._csr.public_key() + if not isinstance(key_object_public_key, rsa.RSAPublicKey): + logger.warning("Key is not an RSA key") + return False + if not isinstance(csr_object_public_key, rsa.RSAPublicKey): + logger.warning("CSR is not an RSA key") + return False + if ( + csr_object_public_key.public_numbers().n + != key_object_public_key.public_numbers().n + ): + logger.warning("Public key numbers between CSR and key do not match") + return False + except ValueError: + logger.warning("Could not load certificate or CSR.") + return False + return True + + def get_sha256_hex(self) -> str: + """Calculate the hash of the provided data and return the hexadecimal representation.""" + digest = hashes.Hash(hashes.SHA256()) + digest.update(self.raw.encode()) + return digest.finalize().hex() + + def sign( + self, ca: Certificate, ca_private_key: PrivateKey, validity: timedelta, is_ca: bool = False + ) -> Certificate: + """Sign this CSR with the given CA and CA private key. + + Args: + ca: The CA certificate. + ca_private_key: The CA private key. + validity: The validity period of the certificate. + is_ca: Whether the generated certificate is a CA certificate. + + Returns: + Certificate: The signed certificate. + """ + return Certificate.generate( + csr=self, + ca=ca, + ca_private_key=ca_private_key, + validity=validity, + is_ca=is_ca, + ) + + @classmethod + def generate( + cls, + attributes: "CertificateRequestAttributes", + private_key: PrivateKey, + ) -> "CertificateSigningRequest": + """Generate a CSR using the supplied attributes and private key. + + Args: + attributes (CertificateRequestAttributes): Certificate request attributes + private_key (PrivateKey): Private key + Returns: + CertificateSigningRequest: CSR + """ + signing_key = private_key._private_key + assert isinstance(signing_key, CertificateIssuerPrivateKeyTypes) + + csr_builder = x509.CertificateSigningRequestBuilder() + if subject_name := _extract_subject_name_attributes(attributes): + csr_builder = csr_builder.subject_name(subject_name) + + _sans: List[x509.GeneralName] = [] + if attributes.sans_oid: + _sans.extend( + [x509.RegisteredID(x509.ObjectIdentifier(san)) for san in attributes.sans_oid] + ) + if attributes.sans_ip: + _sans.extend([x509.IPAddress(ipaddress.ip_address(san)) for san in attributes.sans_ip]) + if attributes.sans_dns: + _sans.extend([x509.DNSName(san) for san in attributes.sans_dns]) + if _sans: + csr_builder = csr_builder.add_extension( + x509.SubjectAlternativeName(set(_sans)), critical=False + ) + if attributes.additional_critical_extensions: + for extension in attributes.additional_critical_extensions: + csr_builder = csr_builder.add_extension(extension, critical=True) + signed_certificate_request = csr_builder.sign(signing_key, hashes.SHA256()) + return cls(x509_object=signed_certificate_request) + + +class CertificateRequestAttributes: + """A representation of the certificate request attributes.""" + + def __init__( + self, + common_name: Optional[str] = None, + sans_dns: Optional[Collection[str]] = None, + sans_ip: Optional[Collection[str]] = None, + sans_oid: Optional[Collection[str]] = None, + email_address: Optional[str] = None, + organization: Optional[str] = None, + organizational_unit: Optional[str] = None, + country_name: Optional[str] = None, + state_or_province_name: Optional[str] = None, + locality_name: Optional[str] = None, + is_ca: bool = False, + add_unique_id_to_subject_name: bool = True, + additional_critical_extensions: Optional[Collection[x509.ExtensionType]] = None, + ): + if not common_name and not sans_dns and not sans_ip and not sans_oid: + raise ValueError( + "At least one of common_name, sans_dns, sans_ip, or sans_oid must be provided" + ) + self._common_name = common_name + self._sans_dns = set(sans_dns) if sans_dns else None + self._sans_ip = set(sans_ip) if sans_ip else None + self._sans_oid = set(sans_oid) if sans_oid else None + self._email_address = email_address + self._organization = organization + self._organizational_unit = organizational_unit + self._country_name = country_name + self._state_or_province_name = state_or_province_name + self._locality_name = locality_name + self._is_ca = is_ca + self._add_unique_id_to_subject_name = add_unique_id_to_subject_name + self._additional_critical_extensions = list(additional_critical_extensions or []) + + @property + def common_name(self) -> str: + """Return the common name.""" + # For legacy interface compatibility, return empty string if not set + return self._common_name if self._common_name else "" + + @property + def sans_dns(self) -> Optional[Set[str]]: + """Return the DNS Subject Alternative Names.""" + return self._sans_dns + + @property + def sans_ip(self) -> Optional[Set[str]]: + """Return the IP Subject Alternative Names.""" + return self._sans_ip + + @property + def sans_oid(self) -> Optional[Set[str]]: + """Return the OID Subject Alternative Names.""" + return self._sans_oid + + @property + def email_address(self) -> Optional[str]: + """Return the email address.""" + return self._email_address + + @property + def organization(self) -> Optional[str]: + """Return the organization name.""" + return self._organization + + @property + def organizational_unit(self) -> Optional[str]: + """Return the organizational unit name.""" + return self._organizational_unit + + @property + def country_name(self) -> Optional[str]: + """Return the country name.""" + return self._country_name + + @property + def state_or_province_name(self) -> Optional[str]: + """Return the state or province name.""" + return self._state_or_province_name + + @property + def locality_name(self) -> Optional[str]: + """Return the locality name.""" + return self._locality_name + + @property + def is_ca(self) -> bool: + """Return whether the certificate is a CA certificate.""" + return self._is_ca + + @property + def add_unique_id_to_subject_name(self) -> bool: + """Return whether to add a unique identifier to the subject name.""" + return self._add_unique_id_to_subject_name + + @property + def additional_critical_extensions(self) -> List[x509.ExtensionType]: + """Return additional critical extensions to be added to the CSR.""" + return self._additional_critical_extensions + + @classmethod + def from_csr( + cls, csr: CertificateSigningRequest, is_ca: bool + ) -> "CertificateRequestAttributes": + """Create CertificateRequestAttributes from a CertificateSigningRequest. + + Args: + csr: The CSR to extract attributes from. + is_ca: Whether a CA certificate is being requested. + + Returns: + CertificateRequestAttributes: The extracted attributes. + """ + return cls( + common_name=csr.common_name, + sans_dns=csr.sans_dns, + sans_ip=csr.sans_ip, + sans_oid=csr.sans_oid, + email_address=csr.email_address, + organization=csr.organization, + organizational_unit=csr.organizational_unit, + country_name=csr.country_name, + state_or_province_name=csr.state_or_province_name, + locality_name=csr.locality_name, + is_ca=is_ca, + add_unique_id_to_subject_name=csr.has_unique_identifier, + additional_critical_extensions=csr.additional_critical_extensions, + ) + + def __eq__(self, other: object) -> bool: + """Check if two CertificateRequestAttributes objects are equal.""" + if not isinstance(other, CertificateRequestAttributes): + return NotImplemented + return ( + self.common_name == other.common_name + and self.sans_dns == other.sans_dns + and self.sans_ip == other.sans_ip + and self.sans_oid == other.sans_oid + and self.email_address == other.email_address + and self.organization == other.organization + and self.organizational_unit == other.organizational_unit + and self.country_name == other.country_name + and self.state_or_province_name == other.state_or_province_name + and self.locality_name == other.locality_name + and self.is_ca == other.is_ca + and self.add_unique_id_to_subject_name == other.add_unique_id_to_subject_name + and self.additional_critical_extensions == other.additional_critical_extensions + ) + + def is_valid(self) -> bool: + """Validate the attributes of the certificate request. + + Returns: + bool: True if the attributes are valid, False otherwise. + """ + if not self.common_name and not self.sans_dns and not self.sans_ip and not self.sans_oid: + logger.warning( + "At least one of common_name, sans_dns, sans_ip, or sans_oid must be provided" + ) + return False + return True + + def generate_csr( + self, + private_key: PrivateKey, + ) -> CertificateSigningRequest: + """Generate a CSR using the current attributes and a private key. + + Args: + private_key (PrivateKey): Private key to sign the CSR. + + Returns: + CertificateSigningRequest: The generated CSR. + """ + return CertificateSigningRequest.generate(self, private_key) + + +@dataclass(frozen=True) +class ProviderCertificate: + """This class represents a certificate provided by the TLS provider.""" + + relation_id: int + certificate: Certificate + certificate_signing_request: CertificateSigningRequest + ca: Certificate + chain: List[Certificate] + revoked: Optional[bool] = None + + def to_json(self) -> str: + """Return the object as a JSON string. + + Returns: + str: JSON representation of the object + """ + return json.dumps( + { + "csr": str(self.certificate_signing_request), + "certificate": str(self.certificate), + "ca": str(self.ca), + "chain": [str(cert) for cert in self.chain], + "revoked": self.revoked, + } + ) + + +@dataclass(frozen=True) +class RequirerCertificateRequest: + """This class represents a certificate signing request requested by a specific TLS requirer.""" + + relation_id: int + certificate_signing_request: CertificateSigningRequest + is_ca: bool + + +class CertificateAvailableEvent(EventBase): + """Charm Event triggered when a TLS certificate is available.""" + + def __init__( + self, + handle: Handle, + certificate: Certificate, + certificate_signing_request: CertificateSigningRequest, + ca: Certificate, + chain: List[Certificate], + ): + super().__init__(handle) + self.certificate = certificate + self.certificate_signing_request = certificate_signing_request + self.ca = ca + self.chain = chain + + def snapshot(self) -> dict: + """Return snapshot.""" + return { + "certificate": str(self.certificate), + "certificate_signing_request": str(self.certificate_signing_request), + "ca": str(self.ca), + "chain": json.dumps([str(certificate) for certificate in self.chain]), + } + + def restore(self, snapshot: dict): + """Restore snapshot.""" + self.certificate = Certificate.from_string(snapshot["certificate"]) + self.certificate_signing_request = CertificateSigningRequest.from_string( + snapshot["certificate_signing_request"] + ) + self.ca = Certificate.from_string(snapshot["ca"]) + chain_strs = json.loads(snapshot["chain"]) + self.chain = [Certificate.from_string(chain_str) for chain_str in chain_strs] + + def chain_as_pem(self) -> str: + """Return full certificate chain as a PEM string.""" + return "\n\n".join([str(cert) for cert in self.chain]) + + +def generate_private_key( + key_size: int = 2048, + public_exponent: int = 65537, +) -> PrivateKey: + """Generate a private key with the RSA algorithm. + + Args: + key_size (int): Key size in bits, must be at least 2048 bits + public_exponent: Public exponent. + + Returns: + PrivateKey: Private Key + """ + warnings.warn( + "generate_private_key() is deprecated. Use PrivateKey.generate() instead.", + DeprecationWarning, + ) + return PrivateKey.generate(key_size=key_size, public_exponent=public_exponent) + + +def calculate_relative_datetime(target_time: datetime, fraction: float) -> datetime: + """Calculate a datetime that is a given percentage from now to a target time. + + Args: + target_time (datetime): The future datetime to interpolate towards. + fraction (float): Fraction of the interval from now to target_time (0.0-1.0). + 1.0 means return target_time, + 0.9 means return the time after 90% of the interval has passed, + and 0.0 means return now. + """ + if fraction <= 0.0 or fraction > 1.0: + raise ValueError("Invalid fraction. Must be between 0.0 and 1.0") + now = datetime.now(timezone.utc) + time_until_target = target_time - now + return now + time_until_target * fraction + + +def chain_has_valid_order(chain: List[str]) -> bool: + """Check if the chain has a valid order. + + Validates that each certificate in the chain is properly signed by the next certificate. + The chain should be ordered from leaf to root, where each certificate is signed by + the next one in the chain. + + Args: + chain (List[str]): List of certificates in PEM format, ordered from leaf to root + + Returns: + bool: True if the chain has a valid order, False otherwise. + """ + if len(chain) < 2: + return True + + try: + for i in range(len(chain) - 1): + cert = x509.load_pem_x509_certificate(chain[i].encode()) + issuer = x509.load_pem_x509_certificate(chain[i + 1].encode()) + cert.verify_directly_issued_by(issuer) + return True + except (ValueError, TypeError, InvalidSignature): + return False + + +def generate_csr( # noqa: C901 + private_key: PrivateKey, + common_name: str, + sans_dns: Optional[FrozenSet[str]] = frozenset(), + sans_ip: Optional[FrozenSet[str]] = frozenset(), + sans_oid: Optional[FrozenSet[str]] = frozenset(), + organization: Optional[str] = None, + organizational_unit: Optional[str] = None, + email_address: Optional[str] = None, + country_name: Optional[str] = None, + locality_name: Optional[str] = None, + state_or_province_name: Optional[str] = None, + add_unique_id_to_subject_name: bool = True, +) -> CertificateSigningRequest: + """Generate a CSR using private key and subject. + + Args: + private_key (PrivateKey): Private key + common_name (str): Common name + sans_dns (FrozenSet[str]): DNS Subject Alternative Names + sans_ip (FrozenSet[str]): IP Subject Alternative Names + sans_oid (FrozenSet[str]): OID Subject Alternative Names + organization (Optional[str]): Organization name + organizational_unit (Optional[str]): Organizational unit name + email_address (Optional[str]): Email address + country_name (Optional[str]): Country name + state_or_province_name (Optional[str]): State or province name + locality_name (Optional[str]): Locality name + add_unique_id_to_subject_name (bool): Whether a unique ID must be added to the CSR's + subject name. Always leave to "True" when the CSR is used to request certificates + using the tls-certificates relation. + + Returns: + CertificateSigningRequest: CSR + """ + warnings.warn( + "generate_csr() is deprecated. Use CertificateRequestAttributes.generate_csr() or CertificateSigningRequest.generate() instead.", + DeprecationWarning, + ) + return CertificateRequestAttributes( + common_name=common_name, + sans_dns=sans_dns, + sans_ip=sans_ip, + sans_oid=sans_oid, + organization=organization, + organizational_unit=organizational_unit, + email_address=email_address, + country_name=country_name, + state_or_province_name=state_or_province_name, + locality_name=locality_name, + add_unique_id_to_subject_name=add_unique_id_to_subject_name, + ).generate_csr(private_key=private_key) + + +def generate_ca( + private_key: PrivateKey, + validity: timedelta, + common_name: str, + sans_dns: Optional[FrozenSet[str]] = frozenset(), + sans_ip: Optional[FrozenSet[str]] = frozenset(), + sans_oid: Optional[FrozenSet[str]] = frozenset(), + organization: Optional[str] = None, + organizational_unit: Optional[str] = None, + email_address: Optional[str] = None, + country_name: Optional[str] = None, + state_or_province_name: Optional[str] = None, + locality_name: Optional[str] = None, +) -> Certificate: + """Generate a self signed CA Certificate. + + Args: + private_key: Private key + validity: Certificate validity time + common_name: Common Name that can be an IP or a Full Qualified Domain Name (FQDN). + sans_dns: DNS Subject Alternative Names + sans_ip: IP Subject Alternative Names + sans_oid: OID Subject Alternative Names + organization: Organization name + organizational_unit: Organizational unit name + email_address: Email address + country_name: Certificate Issuing country + state_or_province_name: Certificate Issuing state or province + locality_name: Certificate Issuing locality + + Returns: + CA Certificate. + """ + warnings.warn( + "generate_ca() is deprecated. Use Certificate.generate_self_signed_ca() instead.", + DeprecationWarning, + ) + attributes = CertificateRequestAttributes( + common_name=common_name, + sans_dns=sans_dns, + sans_ip=sans_ip, + sans_oid=sans_oid, + organization=organization, + organizational_unit=organizational_unit, + email_address=email_address, + country_name=country_name, + state_or_province_name=state_or_province_name, + locality_name=locality_name, + is_ca=True, + ) + return Certificate.generate_self_signed_ca(attributes, private_key, validity) + + +def _san_extension( + email_address: Optional[str] = None, + sans_dns: Optional[Collection[str]] = frozenset(), + sans_ip: Optional[Collection[str]] = frozenset(), + sans_oid: Optional[Collection[str]] = frozenset(), +) -> Optional[x509.SubjectAlternativeName]: + sans: List[x509.GeneralName] = [] + if email_address: + # If an e-mail address was provided, it should always be in the SAN + sans.append(x509.RFC822Name(email_address)) + if sans_dns: + sans.extend([x509.DNSName(san) for san in sans_dns]) + if sans_ip: + sans.extend([x509.IPAddress(ipaddress.ip_address(san)) for san in sans_ip]) + if sans_oid: + sans.extend([x509.RegisteredID(x509.ObjectIdentifier(san)) for san in sans_oid]) + if not sans: + return None + return x509.SubjectAlternativeName(sans) + + +def generate_certificate( + csr: CertificateSigningRequest, + ca: Certificate, + ca_private_key: PrivateKey, + validity: timedelta, + is_ca: bool = False, +) -> Certificate: + """Generate a TLS certificate based on a CSR. + + Args: + csr (CertificateSigningRequest): CSR + ca (Certificate): CA Certificate + ca_private_key (PrivateKey): CA private key + validity (timedelta): Certificate validity time + is_ca (bool): Whether the certificate is a CA certificate + + Returns: + Certificate: Certificate + """ + warnings.warn( + "generate_certificate() is deprecated. Use Certificate.generate() instead.", + DeprecationWarning, + ) + return Certificate.generate( + csr=csr, + ca=ca, + ca_private_key=ca_private_key, + validity=validity, + is_ca=is_ca, + ) + + +def _extract_subject_name_attributes( + attributes: CertificateRequestAttributes, +) -> Optional[x509.Name]: + subject_name_attributes = [] + if attributes.common_name: + subject_name_attributes.append( + x509.NameAttribute(x509.NameOID.COMMON_NAME, attributes.common_name) + ) + if attributes.add_unique_id_to_subject_name: + unique_identifier = uuid.uuid4() + subject_name_attributes.append( + x509.NameAttribute(x509.NameOID.X500_UNIQUE_IDENTIFIER, str(unique_identifier)) + ) + if attributes.organization: + subject_name_attributes.append( + x509.NameAttribute(x509.NameOID.ORGANIZATION_NAME, attributes.organization) + ) + if attributes.organizational_unit: + subject_name_attributes.append( + x509.NameAttribute( + x509.NameOID.ORGANIZATIONAL_UNIT_NAME, + attributes.organizational_unit, + ) + ) + if attributes.email_address: + subject_name_attributes.append( + x509.NameAttribute(x509.NameOID.EMAIL_ADDRESS, attributes.email_address) + ) + if attributes.country_name: + subject_name_attributes.append( + x509.NameAttribute(x509.NameOID.COUNTRY_NAME, attributes.country_name) + ) + if attributes.state_or_province_name: + subject_name_attributes.append( + x509.NameAttribute( + x509.NameOID.STATE_OR_PROVINCE_NAME, + attributes.state_or_province_name, + ) + ) + if attributes.locality_name: + subject_name_attributes.append( + x509.NameAttribute(x509.NameOID.LOCALITY_NAME, attributes.locality_name) + ) + + if subject_name_attributes: + return x509.Name(subject_name_attributes) + + return None + + +def _generate_certificate_request_extensions( + authority_key_identifier: bytes, + csr: x509.CertificateSigningRequest, + is_ca: bool, +) -> List[x509.Extension]: + """Generate a list of certificate extensions from a CSR and other known information. + + Args: + authority_key_identifier (bytes): Authority key identifier + csr (x509.CertificateSigningRequest): CSR + is_ca (bool): Whether the certificate is a CA certificate + + Returns: + List[x509.Extension]: List of extensions + """ + cert_extensions_list: List[x509.Extension] = [ + x509.Extension( + oid=ExtensionOID.AUTHORITY_KEY_IDENTIFIER, + value=x509.AuthorityKeyIdentifier( + key_identifier=authority_key_identifier, + authority_cert_issuer=None, + authority_cert_serial_number=None, + ), + critical=False, + ), + x509.Extension( + oid=ExtensionOID.SUBJECT_KEY_IDENTIFIER, + value=x509.SubjectKeyIdentifier.from_public_key(csr.public_key()), + critical=False, + ), + x509.Extension( + oid=ExtensionOID.BASIC_CONSTRAINTS, + critical=True, + value=x509.BasicConstraints(ca=is_ca, path_length=None), + ), + ] + if sans := _generate_subject_alternative_name_extension(csr): + cert_extensions_list.append(sans) + + if is_ca: + cert_extensions_list.append( + x509.Extension( + ExtensionOID.KEY_USAGE, + critical=True, + value=x509.KeyUsage( + digital_signature=False, + content_commitment=False, + key_encipherment=False, + data_encipherment=False, + key_agreement=False, + key_cert_sign=True, + crl_sign=True, + encipher_only=False, + decipher_only=False, + ), + ) + ) + + existing_oids = {ext.oid for ext in cert_extensions_list} + for extension in csr.extensions: + if extension.oid == ExtensionOID.SUBJECT_ALTERNATIVE_NAME: + continue + if extension.oid in existing_oids: + logger.warning("Extension %s is managed by the TLS provider, ignoring.", extension.oid) + continue + cert_extensions_list.append(extension) + + return cert_extensions_list + + +def _generate_subject_alternative_name_extension( + csr: x509.CertificateSigningRequest, +) -> Optional[x509.Extension]: + sans: List[x509.GeneralName] = [] + try: + loaded_san_ext = csr.extensions.get_extension_for_class(x509.SubjectAlternativeName) + sans.extend( + [x509.DNSName(name) for name in loaded_san_ext.value.get_values_for_type(x509.DNSName)] + ) + sans.extend( + [x509.IPAddress(ip) for ip in loaded_san_ext.value.get_values_for_type(x509.IPAddress)] + ) + sans.extend( + [ + x509.RegisteredID(oid) + for oid in loaded_san_ext.value.get_values_for_type(x509.RegisteredID) + ] + ) + sans.extend( + [ + x509.RFC822Name(name) + for name in loaded_san_ext.value.get_values_for_type(x509.RFC822Name) + ] + ) + except x509.ExtensionNotFound: + pass + # If email is present in the CSR Subject, make sure it is also in the SANS + # to conform to RFC 5280. + email = csr.subject.get_attributes_for_oid(NameOID.EMAIL_ADDRESS) + if email: + email_rfc822 = x509.RFC822Name(str(email[0].value)) + if email_rfc822 not in sans: + sans.append(email_rfc822) + + return ( + x509.Extension( + oid=ExtensionOID.SUBJECT_ALTERNATIVE_NAME, + critical=False, + value=x509.SubjectAlternativeName(sans), + ) + if sans + else None + ) + + +class CertificatesRequirerCharmEvents(CharmEvents): + """List of events that the TLS Certificates requirer charm can leverage.""" + + certificate_available = EventSource(CertificateAvailableEvent) + + +class TLSCertificatesRequiresV4(Object): + """A class to manage the TLS certificates interface for a unit or app.""" + + on = CertificatesRequirerCharmEvents() # type: ignore[reportAssignmentType] + + def __init__( + self, + charm: CharmBase, + relationship_name: str, + certificate_requests: List[CertificateRequestAttributes], + mode: Mode = Mode.UNIT, + refresh_events: List[BoundEvent] = [], + private_key: Optional[PrivateKey] = None, + renewal_relative_time: float = 0.9, + ): + """Create a new instance of the TLSCertificatesRequiresV4 class. + + Args: + charm (CharmBase): The charm instance to relate to. + relationship_name (str): The name of the relation that provides the certificates. + certificate_requests (List[CertificateRequestAttributes]): + A list with the attributes of the certificate requests. + mode (Mode): Whether to use unit or app certificates mode. Default is Mode.UNIT. + In UNIT mode the requirer will place the csr in the unit relation data. + Each unit will manage its private key, + certificate signing request and certificate. + UNIT mode is for use cases where each unit has its own identity. + If you don't know which mode to use, you likely need UNIT. + In APP mode the leader unit will place the csr in the app relation databag. + APP mode is for use cases where the underlying application needs the certificate + for example using it as an intermediate CA to sign other certificates. + The certificate can only be accessed by the leader unit. + refresh_events (List[BoundEvent]): A list of events to trigger a refresh of + the certificates. + private_key (Optional[PrivateKey]): The private key to use for the certificates. + If provided, it will be used instead of generating a new one. + If the key is not valid an exception will be raised. + Using this parameter is discouraged, + having to pass around private keys manually can be a security concern. + Allowing the library to generate and manage the key is the more secure approach. + renewal_relative_time (float): The time to renew the certificate relative to its + expiry. + Default is 0.9, meaning 90% of the validity period. + The minimum value is 0.5, meaning 50% of the validity period. + If an invalid value is provided, an exception will be raised. + """ + super().__init__(charm, relationship_name) + if not JujuVersion.from_environ().has_secrets: + logger.warning("This version of the TLS library requires Juju secrets (Juju >= 3.0)") + if not self._mode_is_valid(mode): + raise TLSCertificatesError("Invalid mode. Must be Mode.UNIT or Mode.APP") + for certificate_request in certificate_requests: + if not certificate_request.is_valid(): + raise TLSCertificatesError("Invalid certificate request") + self.charm = charm + self.relationship_name = relationship_name + self.certificate_requests = certificate_requests + self.mode = mode + if private_key and not private_key.is_valid(): + raise TLSCertificatesError("Invalid private key") + if renewal_relative_time <= 0.5 or renewal_relative_time > 1.0: + raise TLSCertificatesError( + "Invalid renewal relative time. Must be between 0.5 and 1.0" + ) + self._private_key = private_key + self.renewal_relative_time = renewal_relative_time + self.framework.observe(charm.on[relationship_name].relation_created, self._configure) + self.framework.observe(charm.on[relationship_name].relation_changed, self._configure) + self.framework.observe(charm.on.secret_expired, self._on_secret_expired) + self.framework.observe(charm.on.secret_remove, self._on_secret_remove) + for event in refresh_events: + self.framework.observe(event, self._configure) + self._security_logger = _OWASPLogger(application=f"tls-certificates-{charm.app.name}") + + def _configure(self, _: Optional[EventBase] = None): + """Handle TLS Certificates Relation Data. + + This method is called during any TLS relation event. + It will generate a private key if it doesn't exist yet. + It will send certificate requests if they haven't been sent yet. + It will find available certificates and emit events. + """ + if not self._tls_relation_created(): + logger.debug("TLS relation not created yet.") + return + self._ensure_private_key() + self._cleanup_certificate_requests() + self._send_certificate_requests() + self._find_available_certificates() + + def _mode_is_valid(self, mode: Mode) -> bool: + return mode in [Mode.UNIT, Mode.APP] + + def _validate_secret_exists(self, secret: Secret) -> None: + secret.get_info() # Will raise `SecretNotFoundError` if the secret does not exist + + def _on_secret_remove(self, event: SecretRemoveEvent) -> None: + """Handle Secret Removed Event.""" + try: + # Ensure the secret exists before trying to remove it, otherwise + # the unit could be stuck in an error state. See the docstring of + # `remove_revision` and the below issue for more information. + # https://github.com/juju/juju/issues/19036 + self._validate_secret_exists(event.secret) + event.secret.remove_revision(event.revision) + except SecretNotFoundError: + logger.warning( + "No such secret %s, nothing to remove", + event.secret.label or event.secret.id, + ) + return + + def _on_secret_expired(self, event: SecretExpiredEvent) -> None: + """Handle Secret Expired Event. + + Renews certificate requests and removes the expired secret. + """ + if not event.secret.label or not event.secret.label.startswith(f"{LIBID}-certificate"): + return + try: + csr_str = event.secret.get_content(refresh=True)["csr"] + except ModelError: + logger.error("Failed to get CSR from secret - Skipping") + return + csr = CertificateSigningRequest.from_string(csr_str) + self._renew_certificate_request(csr) + event.secret.remove_all_revisions() + + def sync(self) -> None: + """Sync TLS Certificates Relation Data. + + This method allows the requirer to sync the TLS certificates relation data + without waiting for the refresh events to be triggered. + """ + self._configure() + + def renew_certificate(self, certificate: ProviderCertificate) -> None: + """Request the renewal of the provided certificate.""" + certificate_signing_request = certificate.certificate_signing_request + secret_label = self._get_csr_secret_label(certificate_signing_request) + try: + secret = self.model.get_secret(label=secret_label) + except SecretNotFoundError: + logger.warning("No matching secret found - Skipping renewal") + return + current_csr = secret.get_content(refresh=True).get("csr", "") + if current_csr != str(certificate_signing_request): + logger.warning("No matching CSR found - Skipping renewal") + return + self._renew_certificate_request(certificate_signing_request) + secret.remove_all_revisions() + + def _renew_certificate_request(self, csr: CertificateSigningRequest): + """Remove existing CSR from relation data and create a new one.""" + self._remove_requirer_csr_from_relation_data(csr) + self._send_certificate_requests() + logger.info("Renewed certificate request") + + def _remove_requirer_csr_from_relation_data(self, csr: CertificateSigningRequest) -> None: + relation = self.model.get_relation(self.relationship_name) + if not relation: + logger.debug("No relation: %s", self.relationship_name) + return + if not self.get_csrs_from_requirer_relation_data(): + logger.info("No CSRs in relation data - Doing nothing") + return + app_or_unit = self._get_app_or_unit() + try: + requirer_relation_data = _RequirerData.load(relation.data[app_or_unit]) + except DataValidationError: + logger.warning("Invalid relation data - Skipping removal of CSR") + return + new_relation_data = copy.deepcopy(requirer_relation_data.certificate_signing_requests) + for requirer_csr in new_relation_data: + if requirer_csr.certificate_signing_request.strip() == str(csr).strip(): + new_relation_data.remove(requirer_csr) + try: + _RequirerData(certificate_signing_requests=new_relation_data).dump( + relation.data[app_or_unit] + ) + logger.info("Removed CSR from relation data") + except ModelError: + logger.warning("Failed to update relation data") + + def _get_app_or_unit(self) -> Union[Application, Unit]: + """Return the unit or app object based on the mode.""" + if self.mode == Mode.UNIT: + return self.model.unit + elif self.mode == Mode.APP: + return self.model.app + raise TLSCertificatesError("Invalid mode") + + @property + def private_key(self) -> Optional[PrivateKey]: + """Return the private key.""" + if self._private_key: + return self._private_key + if not self._private_key_generated(): + return None + secret = self.charm.model.get_secret(label=self._get_private_key_secret_label()) + private_key = secret.get_content(refresh=True)["private-key"] + return PrivateKey.from_string(private_key) + + def _ensure_private_key(self) -> None: + """Make sure there is a private key to be used. + + It will make sure there is a private key passed by the charm using the private_key + parameter or generate a new one otherwise. + """ + # Remove the generated private key + # if one has been passed by the charm using the private_key parameter + if self._private_key: + self._remove_private_key_secret() + return + if self._private_key_generated(): + logger.debug("Private key already generated") + return + self._generate_private_key() + + def regenerate_private_key(self) -> None: + """Regenerate the private key. + + Generate a new private key, remove old certificate requests and send new ones. + + Raises: + TLSCertificatesError: If the private key is passed by the charm using the + private_key parameter. + """ + if self._private_key: + raise TLSCertificatesError( + "Private key is passed by the charm through the private_key parameter, this function can't be used" + ) + if not self._private_key_generated(): + logger.warning("No private key to regenerate") + return + self._generate_private_key() + self._cleanup_certificate_requests() + self._send_certificate_requests() + + def _generate_private_key(self) -> None: + """Generate a new private key and store it in a secret. + + This is the case when the private key used is generated by the library. + and not passed by the charm using the private_key parameter. + """ + self._store_private_key_in_secret(generate_private_key()) + logger.info("Private key generated") + + def _private_key_generated(self) -> bool: + """Check if a private key is stored in a secret. + + This is the case when the private key used is generated by the library. + This should not exist when the private key used + is passed by the charm using the private_key parameter. + """ + try: + secret = self.charm.model.get_secret(label=self._get_private_key_secret_label()) + secret.get_content(refresh=True) + return True + except SecretNotFoundError: + return False + + def _store_private_key_in_secret(self, private_key: PrivateKey) -> None: + try: + secret = self.charm.model.get_secret(label=self._get_private_key_secret_label()) + secret.set_content({"private-key": str(private_key)}) + secret.get_content(refresh=True) + except SecretNotFoundError: + self.charm.unit.add_secret( + content={"private-key": str(private_key)}, + label=self._get_private_key_secret_label(), + ) + + def _remove_private_key_secret(self) -> None: + """Remove the private key secret.""" + try: + secret = self.charm.model.get_secret(label=self._get_private_key_secret_label()) + secret.remove_all_revisions() + except SecretNotFoundError: + logger.warning("Private key secret not found, nothing to remove") + + def _csr_matches_certificate_request( + self, certificate_signing_request: CertificateSigningRequest, is_ca: bool + ) -> bool: + for certificate_request in self.certificate_requests: + if certificate_request == CertificateRequestAttributes.from_csr( + certificate_signing_request, + is_ca, + ): + return True + return False + + def _certificate_requested(self, certificate_request: CertificateRequestAttributes) -> bool: + if not self.private_key: + return False + csr = self._certificate_requested_for_attributes(certificate_request) + if not csr: + return False + if not csr.certificate_signing_request.matches_private_key(key=self.private_key): + return False + return True + + def _certificate_requested_for_attributes( + self, + certificate_request: CertificateRequestAttributes, + ) -> Optional[RequirerCertificateRequest]: + for requirer_csr in self.get_csrs_from_requirer_relation_data(): + if certificate_request == CertificateRequestAttributes.from_csr( + requirer_csr.certificate_signing_request, + requirer_csr.is_ca, + ): + return requirer_csr + return None + + def get_csrs_from_requirer_relation_data(self) -> List[RequirerCertificateRequest]: + """Return list of requirer's CSRs from relation data.""" + if self.mode == Mode.APP and not self.model.unit.is_leader(): + logger.debug("Not a leader unit - Skipping") + return [] + relation = self.model.get_relation(self.relationship_name) + if not relation: + logger.debug("No relation: %s", self.relationship_name) + return [] + app_or_unit = self._get_app_or_unit() + try: + requirer_relation_data = _RequirerData.load(relation.data[app_or_unit]) + except DataValidationError: + logger.warning("Invalid relation data") + return [] + requirer_csrs = [] + for csr in requirer_relation_data.certificate_signing_requests: + requirer_csrs.append( + RequirerCertificateRequest( + relation_id=relation.id, + certificate_signing_request=CertificateSigningRequest.from_string( + csr.certificate_signing_request + ), + is_ca=csr.ca if csr.ca else False, + ) + ) + return requirer_csrs + + def get_provider_certificates(self) -> List[ProviderCertificate]: + """Return list of certificates from the provider's relation data.""" + return self._load_provider_certificates() + + def _load_provider_certificates(self) -> List[ProviderCertificate]: + relation = self.model.get_relation(self.relationship_name) + if not relation: + logger.debug("No relation: %s", self.relationship_name) + return [] + if not relation.app: + logger.debug("No remote app in relation: %s", self.relationship_name) + return [] + try: + provider_relation_data = _ProviderApplicationData.load(relation.data[relation.app]) + except DataValidationError: + logger.warning("Invalid relation data") + return [] + return [ + certificate.to_provider_certificate(relation_id=relation.id) + for certificate in provider_relation_data.certificates + ] + + def _request_certificate(self, csr: CertificateSigningRequest, is_ca: bool) -> None: + """Add CSR to relation data.""" + if self.mode == Mode.APP and not self.model.unit.is_leader(): + logger.debug("Not a leader unit - Skipping") + return + relation = self.model.get_relation(self.relationship_name) + if not relation: + logger.debug("No relation: %s", self.relationship_name) + return + new_csr = _CertificateSigningRequest( + certificate_signing_request=str(csr).strip(), ca=is_ca + ) + app_or_unit = self._get_app_or_unit() + try: + requirer_relation_data = _RequirerData.load(relation.data[app_or_unit]) + except DataValidationError: + requirer_relation_data = _RequirerData( + certificate_signing_requests=[], + ) + new_relation_data = copy.deepcopy(requirer_relation_data.certificate_signing_requests) + new_relation_data.append(new_csr) + try: + _RequirerData(certificate_signing_requests=new_relation_data).dump( + relation.data[app_or_unit] + ) + logger.info("Certificate signing request added to relation data.") + except ModelError: + logger.warning("Failed to update relation data") + + def _send_certificate_requests(self): + if not self.private_key: + logger.debug("Private key not generated yet.") + return + for certificate_request in self.certificate_requests: + if not self._certificate_requested(certificate_request): + csr = certificate_request.generate_csr( + private_key=self.private_key, + ) + if not csr: + logger.warning("Failed to generate CSR") + continue + self._request_certificate(csr=csr, is_ca=certificate_request.is_ca) + + def get_assigned_certificate( + self, certificate_request: CertificateRequestAttributes + ) -> Tuple[Optional[ProviderCertificate], Optional[PrivateKey]]: + """Get the certificate that was assigned to the given certificate request.""" + for requirer_csr in self.get_csrs_from_requirer_relation_data(): + if certificate_request == CertificateRequestAttributes.from_csr( + requirer_csr.certificate_signing_request, + requirer_csr.is_ca, + ): + return self._find_certificate_in_relation_data(requirer_csr), self.private_key + return None, None + + def get_assigned_certificates( + self, + ) -> Tuple[List[ProviderCertificate], Optional[PrivateKey]]: + """Get a list of certificates that were assigned to this or app.""" + assigned_certificates = [] + for requirer_csr in self.get_csrs_from_requirer_relation_data(): + if cert := self._find_certificate_in_relation_data(requirer_csr): + assigned_certificates.append(cert) + return assigned_certificates, self.private_key + + def _find_certificate_in_relation_data( + self, csr: RequirerCertificateRequest + ) -> Optional[ProviderCertificate]: + """Return the certificate that matches the given CSR, validated against the private key.""" + if not self.private_key: + return None + for provider_certificate in self.get_provider_certificates(): + if provider_certificate.certificate_signing_request == csr.certificate_signing_request: + if provider_certificate.certificate.is_ca and not csr.is_ca: + logger.warning("Non CA certificate requested, got a CA certificate, ignoring") + continue + elif not provider_certificate.certificate.is_ca and csr.is_ca: + logger.warning("CA certificate requested, got a non CA certificate, ignoring") + continue + if not provider_certificate.certificate.matches_private_key(self.private_key): + logger.warning( + "Certificate does not match the private key. Ignoring invalid certificate." + ) + continue + return provider_certificate + return None + + def _find_available_certificates(self): + """Find available certificates and emit events. + + This method will find certificates that are available for the requirer's CSRs. + If a certificate is found, it will be set as a secret and an event will be emitted. + If a certificate is revoked, the secret will be removed and an event will be emitted. + """ + requirer_csrs = self.get_csrs_from_requirer_relation_data() + csrs = [csr.certificate_signing_request for csr in requirer_csrs] + provider_certificates = self.get_provider_certificates() + for provider_certificate in provider_certificates: + if provider_certificate.certificate_signing_request in csrs: + secret_label = self._get_csr_secret_label( + provider_certificate.certificate_signing_request + ) + if provider_certificate.revoked: + with suppress(SecretNotFoundError): + logger.debug( + "Removing secret with label %s", + secret_label, + ) + secret = self.model.get_secret(label=secret_label) + secret.remove_all_revisions() + else: + if not self._csr_matches_certificate_request( + certificate_signing_request=provider_certificate.certificate_signing_request, + is_ca=provider_certificate.certificate.is_ca, + ): + logger.debug("Certificate requested for different attributes - Skipping") + continue + try: + secret = self.model.get_secret(label=secret_label) + logger.debug("Setting secret with label %s", secret_label) + # Juju < 3.6 will create a new revision even if the content is the same + if secret.get_content(refresh=True).get("certificate", "") == str( + provider_certificate.certificate + ): + logger.debug( + "Secret %s with correct certificate already exists", secret_label + ) + continue + secret.set_content( + content={ + "certificate": str(provider_certificate.certificate), + "csr": str(provider_certificate.certificate_signing_request), + } + ) + secret.set_info( + expire=calculate_relative_datetime( + target_time=provider_certificate.certificate.expiry_time, + fraction=self.renewal_relative_time, + ), + ) + secret.get_content(refresh=True) + except SecretNotFoundError: + logger.debug("Creating new secret with label %s", secret_label) + secret = self.charm.unit.add_secret( + content={ + "certificate": str(provider_certificate.certificate), + "csr": str(provider_certificate.certificate_signing_request), + }, + label=secret_label, + expire=calculate_relative_datetime( + target_time=provider_certificate.certificate.expiry_time, + fraction=self.renewal_relative_time, + ), + ) + self.on.certificate_available.emit( + certificate_signing_request=provider_certificate.certificate_signing_request, + certificate=provider_certificate.certificate, + ca=provider_certificate.ca, + chain=provider_certificate.chain, + ) + + def _cleanup_certificate_requests(self): + """Clean up certificate requests. + + Remove any certificate requests that falls into one of the following categories: + - The CSR attributes do not match any of the certificate requests defined in + the charm's certificate_requests attribute. + - The CSR public key does not match the private key. + """ + for requirer_csr in self.get_csrs_from_requirer_relation_data(): + if not self._csr_matches_certificate_request( + certificate_signing_request=requirer_csr.certificate_signing_request, + is_ca=requirer_csr.is_ca, + ): + self._remove_requirer_csr_from_relation_data( + requirer_csr.certificate_signing_request + ) + logger.info( + "Removed CSR from relation data because it did not match any certificate request" # noqa: E501 + ) + elif ( + self.private_key + and not requirer_csr.certificate_signing_request.matches_private_key( + self.private_key + ) + ): + self._remove_requirer_csr_from_relation_data( + requirer_csr.certificate_signing_request + ) + logger.info( + "Removed CSR from relation data because it did not match the private key" + ) # noqa: E501 + + def _tls_relation_created(self) -> bool: + relation = self.model.get_relation(self.relationship_name) + if not relation: + return False + return True + + def _get_private_key_secret_label(self) -> str: + if self.mode == Mode.UNIT: + return f"{LIBID}-private-key-{self._get_unit_number()}-{self.relationship_name}" + elif self.mode == Mode.APP: + return f"{LIBID}-private-key-{self.relationship_name}" + else: + raise TLSCertificatesError("Invalid mode. Must be Mode.UNIT or Mode.APP.") + + def _get_csr_secret_label(self, csr: CertificateSigningRequest) -> str: + csr_in_sha256_hex = csr.get_sha256_hex() + if self.mode == Mode.UNIT: + return f"{LIBID}-certificate-{self._get_unit_number()}-{csr_in_sha256_hex}" + elif self.mode == Mode.APP: + return f"{LIBID}-certificate-{csr_in_sha256_hex}" + else: + raise TLSCertificatesError("Invalid mode. Must be Mode.UNIT or Mode.APP.") + + def _get_unit_number(self) -> str: + return self.model.unit.name.split("/")[1] + + +class TLSCertificatesProvidesV4(Object): + """TLS certificates provider class to be instantiated by TLS certificates providers.""" + + def __init__(self, charm: CharmBase, relationship_name: str): + super().__init__(charm, relationship_name) + self.framework.observe(charm.on[relationship_name].relation_joined, self._configure) + self.framework.observe(charm.on[relationship_name].relation_changed, self._configure) + self.framework.observe(charm.on.update_status, self._configure) + self.charm = charm + self.relationship_name = relationship_name + self._security_logger = _OWASPLogger(application=f"tls-certificates-{charm.app.name}") + + def _configure(self, _: EventBase) -> None: + """Handle update status and tls relation changed events. + + This is a common hook triggered on a regular basis. + + Revoke certificates for which no csr exists + """ + if not self.model.unit.is_leader(): + return + self._remove_certificates_for_which_no_csr_exists() + + def _remove_certificates_for_which_no_csr_exists(self) -> None: + provider_certificates = self.get_provider_certificates() + requirer_csrs = [ + request.certificate_signing_request for request in self.get_certificate_requests() + ] + for provider_certificate in provider_certificates: + if provider_certificate.certificate_signing_request not in requirer_csrs: + tls_relation = self._get_tls_relations( + relation_id=provider_certificate.relation_id + ) + self._remove_provider_certificate( + certificate=provider_certificate.certificate, + relation=tls_relation[0], + ) + + def _get_tls_relations(self, relation_id: Optional[int] = None) -> List[Relation]: + return ( + [ + relation + for relation in self.model.relations[self.relationship_name] + if relation.id == relation_id + ] + if relation_id is not None + else self.model.relations.get(self.relationship_name, []) + ) + + def get_certificate_requests( + self, relation_id: Optional[int] = None + ) -> List[RequirerCertificateRequest]: + """Load certificate requests from the relation data.""" + relations = self._get_tls_relations(relation_id) + requirer_csrs: List[RequirerCertificateRequest] = [] + for relation in relations: + for unit in relation.units: + requirer_csrs.extend(self._load_requirer_databag(relation, unit)) + requirer_csrs.extend(self._load_requirer_databag(relation, relation.app)) + return requirer_csrs + + def _load_requirer_databag( + self, relation: Relation, unit_or_app: Union[Application, Unit] + ) -> List[RequirerCertificateRequest]: + try: + requirer_relation_data = _RequirerData.load(relation.data.get(unit_or_app, {})) + except DataValidationError: + logger.debug("Invalid requirer relation data for %s", unit_or_app.name) + return [] + return [ + RequirerCertificateRequest( + relation_id=relation.id, + certificate_signing_request=CertificateSigningRequest.from_string( + csr.certificate_signing_request + ), + is_ca=csr.ca if csr.ca else False, + ) + for csr in requirer_relation_data.certificate_signing_requests + ] + + def _add_provider_certificate( + self, + relation: Relation, + provider_certificate: ProviderCertificate, + ) -> None: + chain = [str(certificate) for certificate in provider_certificate.chain] + if chain[0] != str(provider_certificate.certificate): + logger.warning( + "The order of the chain from the TLS Certificates Provider is incorrect. " + "The leaf certificate should be the first element of the chain." + ) + elif not chain_has_valid_order(chain): + logger.warning( + "The order of the chain from the TLS Certificates Provider is partially incorrect." + ) + new_certificate = _Certificate( + certificate=str(provider_certificate.certificate), + certificate_signing_request=str(provider_certificate.certificate_signing_request), + ca=str(provider_certificate.ca), + chain=chain, + ) + provider_certificates = self._load_provider_certificates(relation) + if new_certificate in provider_certificates: + logger.info("Certificate already in relation data - Doing nothing") + return + provider_certificates.append(new_certificate) + self._dump_provider_certificates(relation=relation, certificates=provider_certificates) + + def _load_provider_certificates(self, relation: Relation) -> List[_Certificate]: + try: + provider_relation_data = _ProviderApplicationData.load(relation.data[self.charm.app]) + except DataValidationError: + logger.debug("Invalid provider relation data") + return [] + return copy.deepcopy(provider_relation_data.certificates) + + def _dump_provider_certificates(self, relation: Relation, certificates: List[_Certificate]): + try: + _ProviderApplicationData(certificates=certificates).dump(relation.data[self.model.app]) + logger.info("Certificate relation data updated") + except ModelError: + logger.warning("Failed to update relation data") + + def _remove_provider_certificate( + self, + relation: Relation, + certificate: Optional[Certificate] = None, + certificate_signing_request: Optional[CertificateSigningRequest] = None, + ) -> None: + """Remove certificate based on certificate or certificate signing request.""" + provider_certificates = self._load_provider_certificates(relation) + for provider_certificate in provider_certificates: + if certificate and provider_certificate.certificate == str(certificate): + provider_certificates.remove(provider_certificate) + if ( + certificate_signing_request + and provider_certificate.certificate_signing_request + == str(certificate_signing_request) + ): + provider_certificates.remove(provider_certificate) + self._dump_provider_certificates(relation=relation, certificates=provider_certificates) + + def revoke_all_certificates(self) -> None: + """Revoke all certificates of this provider. + + This method is meant to be used when the Root CA has changed. + """ + if not self.model.unit.is_leader(): + logger.warning("Unit is not a leader - will not set relation data") + return + relations = self._get_tls_relations() + for relation in relations: + provider_certificates = self._load_provider_certificates(relation) + for certificate in provider_certificates: + certificate.revoked = True + self._dump_provider_certificates(relation=relation, certificates=provider_certificates) + self._security_logger.log_event( + event="all_certificates_revoked", + level=logging.WARNING, + description="All certificates revoked", + ) + + def set_relation_certificate( + self, + provider_certificate: ProviderCertificate, + ) -> None: + """Add certificates to relation data. + + Args: + provider_certificate (ProviderCertificate): ProviderCertificate object + + Returns: + None + """ + if not self.model.unit.is_leader(): + logger.warning("Unit is not a leader - will not set relation data") + return + certificates_relation = self.model.get_relation( + relation_name=self.relationship_name, relation_id=provider_certificate.relation_id + ) + if not certificates_relation: + raise TLSCertificatesError(f"Relation {self.relationship_name} does not exist") + self._remove_provider_certificate( + relation=certificates_relation, + certificate_signing_request=provider_certificate.certificate_signing_request, + ) + self._add_provider_certificate( + relation=certificates_relation, + provider_certificate=provider_certificate, + ) + self._security_logger.log_event( + event="certificate_provided", + level=logging.INFO, + description="Certificate provided to requirer", + relation_id=str(provider_certificate.relation_id), + common_name=provider_certificate.certificate.common_name, + ) + + def get_issued_certificates( + self, relation_id: Optional[int] = None + ) -> List[ProviderCertificate]: + """Return a List of issued (non revoked) certificates. + + Returns: + List: List of ProviderCertificate objects + """ + if not self.model.unit.is_leader(): + logger.warning("Unit is not a leader - will not read relation data") + return [] + provider_certificates = self.get_provider_certificates(relation_id=relation_id) + return [certificate for certificate in provider_certificates if not certificate.revoked] + + def get_provider_certificates( + self, relation_id: Optional[int] = None + ) -> List[ProviderCertificate]: + """Return a List of issued certificates.""" + certificates: List[ProviderCertificate] = [] + relations = self._get_tls_relations(relation_id) + for relation in relations: + if not relation.app: + logger.warning("Relation %s does not have an application", relation.id) + continue + for certificate in self._load_provider_certificates(relation): + certificates.append(certificate.to_provider_certificate(relation_id=relation.id)) + return certificates + + def get_unsolicited_certificates( + self, relation_id: Optional[int] = None + ) -> List[ProviderCertificate]: + """Return provider certificates for which no certificate requests exists. + + Those certificates should be revoked. + """ + unsolicited_certificates: List[ProviderCertificate] = [] + provider_certificates = self.get_provider_certificates(relation_id=relation_id) + requirer_csrs = self.get_certificate_requests(relation_id=relation_id) + list_of_csrs = [csr.certificate_signing_request for csr in requirer_csrs] + for certificate in provider_certificates: + if certificate.certificate_signing_request not in list_of_csrs: + unsolicited_certificates.append(certificate) + return unsolicited_certificates + + def get_outstanding_certificate_requests( + self, relation_id: Optional[int] = None + ) -> List[RequirerCertificateRequest]: + """Return CSR's for which no certificate has been issued. + + Args: + relation_id (int): Relation id + + Returns: + list: List of RequirerCertificateRequest objects. + """ + requirer_csrs = self.get_certificate_requests(relation_id=relation_id) + outstanding_csrs: List[RequirerCertificateRequest] = [] + for relation_csr in requirer_csrs: + if not self._certificate_issued_for_csr( + csr=relation_csr.certificate_signing_request, + relation_id=relation_id, + ): + outstanding_csrs.append(relation_csr) + return outstanding_csrs + + def _certificate_issued_for_csr( + self, csr: CertificateSigningRequest, relation_id: Optional[int] + ) -> bool: + """Check whether a certificate has been issued for a given CSR.""" + issued_certificates_per_csr = self.get_issued_certificates(relation_id=relation_id) + for issued_certificate in issued_certificates_per_csr: + if issued_certificate.certificate_signing_request == csr: + return csr.matches_certificate(issued_certificate.certificate) + return False diff --git a/dovecot-charm/templates/dovecot.conf.tmpl b/dovecot-charm/templates/dovecot.conf.tmpl index 6c76a4a..5dc5493 100644 --- a/dovecot-charm/templates/dovecot.conf.tmpl +++ b/dovecot-charm/templates/dovecot.conf.tmpl @@ -10,16 +10,16 @@ auth_verbose = yes auth_verbose_passwords = no # TODO: change to ssl = required once TLS relation is added (pr/5-tls) -ssl = yes -ssl_cert = Date: Fri, 3 Apr 2026 13:38:11 +0300 Subject: [PATCH 02/39] docs: add release notes for pr/3-tls --- docs/release-notes/artifacts/pr-3-tls.yaml | 15 +++++++ docs/release-notes/index.rst | 1 + docs/release-notes/release-notes-0004.rst | 52 ++++++++++++++++++++++ 3 files changed, 68 insertions(+) create mode 100644 docs/release-notes/artifacts/pr-3-tls.yaml create mode 100644 docs/release-notes/release-notes-0004.rst diff --git a/docs/release-notes/artifacts/pr-3-tls.yaml b/docs/release-notes/artifacts/pr-3-tls.yaml new file mode 100644 index 0000000..83efdba --- /dev/null +++ b/docs/release-notes/artifacts/pr-3-tls.yaml @@ -0,0 +1,15 @@ +# Version of the artifact schema +version_schema: 2 + +changes: +- title: Added TLS certificate integration via the certificates relation + author: alithethird + type: major + description: Added TLS support using the tls-certificates-interface library. The charm requests certificates via the certificates relation, writes the cert and key to /etc/dovecot/private/, and restarts Dovecot automatically after installation. + urls: + pr: + - "https://github.com/canonical/mailserver-operators/pull/4" + related_doc: + related_issue: + visibility: public + highlight: true diff --git a/docs/release-notes/index.rst b/docs/release-notes/index.rst index 5c349d3..23c635d 100644 --- a/docs/release-notes/index.rst +++ b/docs/release-notes/index.rst @@ -34,3 +34,4 @@ Releases release-notes-0001 release-notes-0002 release-notes-0003 + release-notes-0004 diff --git a/docs/release-notes/release-notes-0004.rst b/docs/release-notes/release-notes-0004.rst new file mode 100644 index 0000000..ab54e1e --- /dev/null +++ b/docs/release-notes/release-notes-0004.rst @@ -0,0 +1,52 @@ +.. _release_notes_release_notes_0004: + +Dovecot release notes – 2.3/edge +================================= + +These release notes cover new features and changes in Dovecot. + +Main features: + +* Added TLS certificate integration via the ``certificates`` relation. + +See our :ref:`Release policy and schedule `. + +Requirements and compatibility +------------------------------- + +The charm operates Dovecot 2.3. + +.. list-table:: + :header-rows: 1 + :widths: 50 50 + + * - Software + - Required version + * - Juju + - 3.x + * - Ubuntu + - 24.04 + +Updates +------- + +The following major and minor features were added in this release. + +Added TLS certificate integration via the ``certificates`` relation +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Added TLS support to the Dovecot charm using the ``tls-certificates-interface`` library. When a ``certificates`` relation is established, the charm requests a certificate for the configured mailname and handles the ``certificate_available`` event by writing the certificate and private key to ``/etc/dovecot/private/``. The Dovecot service is automatically restarted after certificate installation. The Dovecot configuration template was updated to reference the certificate and key paths for IMAPS and POP3S listeners. + +Relevant links: + +* `PR `_ + +Bug fixes +--------- + +No bug fixes in this release. + +Known issues +------------ + +No known issues. From 261d7419c6b3f5be4274c218cf4b78725ff9cff4 Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Thu, 9 Apr 2026 14:43:22 +0300 Subject: [PATCH 03/39] Refactor TLS certificate tests and update dependencies - Updated unit tests for DovecotCharm to use new setup methods for Dovecot and Procmail. - Replaced calls to _systemctl with _setup_dovecot and _setup_procmail in certificate tests. - Ensured service reload is properly mocked and verified in tests. - Added new dependency on charmlibs-interfaces-tls-certificates version 1.8.1 in uv.lock. --- dovecot-charm/charmcraft.yaml | 4 - .../v4/tls_certificates.py | 2526 ----------------- dovecot-charm/pyproject.toml | 1 + dovecot-charm/src/charm.py | 8 +- .../v4/tls_certificates.py | 2525 ---------------- dovecot-charm/uv.lock | 16 + 6 files changed, 21 insertions(+), 5059 deletions(-) delete mode 100644 dovecot-charm/lib/charms/tls_certificates_interface/v4/tls_certificates.py delete mode 100644 dovecot-charm/src/charms/tls_certificates_interface/v4/tls_certificates.py diff --git a/dovecot-charm/charmcraft.yaml b/dovecot-charm/charmcraft.yaml index de95635..7ef911e 100644 --- a/dovecot-charm/charmcraft.yaml +++ b/dovecot-charm/charmcraft.yaml @@ -7,10 +7,6 @@ base: ubuntu@24.04 platforms: amd64: -charm-libs: - - lib: tls-certificates-interface.tls_certificates - version: "4" - name: dovecot-charm summary: Dovecot IMAP/POP3 mail server charm diff --git a/dovecot-charm/lib/charms/tls_certificates_interface/v4/tls_certificates.py b/dovecot-charm/lib/charms/tls_certificates_interface/v4/tls_certificates.py deleted file mode 100644 index 32b3b15..0000000 --- a/dovecot-charm/lib/charms/tls_certificates_interface/v4/tls_certificates.py +++ /dev/null @@ -1,2526 +0,0 @@ -# Copyright 2024 Canonical Ltd. -# See LICENSE file for licensing details. - -"""Legacy Charmhub-hosted lib, deprecated in favour of ``charmlibs.interfaces.tls_certificates``. - -WARNING: This library is deprecated. -It will not receive feature updates or bugfixes. -``charmlibs.interfaces.tls_certificates`` 1.0 is a bug-for-bug compatible migration of this library. - -To migrate: -1. Add 'charmlibs-interfaces-tls-certificates~=1.0' to your charm's dependencies, - and remove this Charmhub-hosted library from your charm. -2. You can also remove any dependencies added to your charm only because of this library. -3. Replace `from charms.tls_certificates_interface.v4 import tls_certificates` - with `from charmlibs.interfaces import tls_certificates`. - -Read more: -- https://documentation.ubuntu.com/charmlibs -- https://pypi.org/project/charmlibs-interfaces-tls-certificates - ---- - -Charm library for managing TLS certificates (V4). - -This library contains the Requires and Provides classes for handling the tls-certificates -interface. - -Pre-requisites: - - Juju >= 3.0 - - cryptography >= 43.0.0 - - pydantic >= 1.0 - -Learn more on how-to use the TLS Certificates interface library by reading the documentation: -- https://charmhub.io/tls-certificates-interface/ - -""" # noqa: D214, D405, D411, D416 - -import copy -import ipaddress -import json -import logging -import uuid -import warnings -from contextlib import suppress -from dataclasses import asdict, dataclass, field -from datetime import datetime, timedelta, timezone -from enum import Enum -from typing import ( - Collection, - Dict, - FrozenSet, - List, - MutableMapping, - Optional, - Set, - Tuple, - Union, -) - -import pydantic -from cryptography import x509 -from cryptography.exceptions import InvalidSignature -from cryptography.hazmat.primitives import hashes, serialization -from cryptography.hazmat.primitives.asymmetric import rsa -from cryptography.hazmat.primitives.asymmetric.types import CertificateIssuerPrivateKeyTypes -from cryptography.x509.oid import ExtensionOID, NameOID -from ops import BoundEvent, CharmBase, CharmEvents, Secret, SecretExpiredEvent, SecretRemoveEvent -from ops.framework import EventBase, EventSource, Handle, Object -from ops.jujuversion import JujuVersion -from ops.model import Application, ModelError, Relation, SecretNotFoundError, Unit - -# The unique Charmhub library identifier, never change it -LIBID = "afd8c2bccf834997afce12c2706d2ede" - -# Increment this major API version when introducing breaking changes -LIBAPI = 4 - -# Increment this PATCH version before using `charmcraft publish-lib` or reset -# to 0 if you are raising the major API version -LIBPATCH = 27 - -PYDEPS = [ - "cryptography>=43.0.0", - "pydantic", -] -IS_PYDANTIC_V1 = int(pydantic.version.VERSION.split(".")[0]) < 2 - -logger = logging.getLogger(__name__) - -NESTED_JSON_KEY = "owasp_event" - - -@dataclass -class _OWASPLogEvent: - """OWASP-compliant log event.""" - - datetime: str - event: str - level: str - description: str - type: str = "security" - labels: Dict[str, str] = field(default_factory=dict) - - def to_json(self) -> str: - return json.dumps(self.to_dict(), ensure_ascii=False) - - def to_dict(self) -> Dict: - log_event = dict(asdict(self), **self.labels) - log_event.pop("labels", None) - return {k: v for k, v in log_event.items() if v is not None} - - -class _OWASPLogger: - """OWASP-compliant logger for security events.""" - - def __init__(self, application: Optional[str] = None): - self.application = application - self._logger = logging.getLogger(__name__) - - def log_event(self, event: str, level: int, description: str, **labels: str): - if self.application and "application" not in labels: - labels["application"] = self.application - log = _OWASPLogEvent( - datetime=datetime.now(timezone.utc).astimezone().isoformat(), - event=event, - level=logging.getLevelName(level), - description=description, - labels=labels, - ) - self._logger.log(level, log.to_json(), extra={NESTED_JSON_KEY: log.to_dict()}) - - -class TLSCertificatesError(Exception): - """Base class for custom errors raised by this library.""" - - -class DataValidationError(TLSCertificatesError): - """Raised when data validation fails.""" - - -class _DatabagModel(pydantic.BaseModel): - """Base databag model. - - Supports both pydantic v1 and v2. - """ - - if IS_PYDANTIC_V1: - - class Config: - """Pydantic config.""" - - # ignore any extra fields in the databag - extra = "ignore" - """Ignore any extra fields in the databag.""" - allow_population_by_field_name = True - """Allow instantiating this class by field name (instead of forcing alias).""" - - _NEST_UNDER = None - - model_config = pydantic.ConfigDict( - # tolerate additional keys in databag - extra="ignore", - # Allow instantiating this class by field name (instead of forcing alias). - populate_by_name=True, - # Custom config key: whether to nest the whole datastructure (as json) - # under a field or spread it out at the toplevel. - _NEST_UNDER=None, - ) # type: ignore - """Pydantic config.""" - - @classmethod - def load(cls, databag: MutableMapping): - """Load this model from a Juju databag.""" - if IS_PYDANTIC_V1: - return cls._load_v1(databag) - nest_under = cls.model_config.get("_NEST_UNDER") - if nest_under: - return cls.model_validate(json.loads(databag[nest_under])) - - try: - data = { - k: json.loads(v) - for k, v in databag.items() - # Don't attempt to parse model-external values - if k in {(f.alias or n) for n, f in cls.model_fields.items()} - } - except json.JSONDecodeError as e: - msg = f"invalid databag contents: expecting json. {databag}" - logger.error(msg) - raise DataValidationError(msg) from e - - try: - return cls.model_validate_json(json.dumps(data)) - except pydantic.ValidationError as e: - msg = f"failed to validate databag: {databag}" - logger.debug(msg, exc_info=True) - raise DataValidationError(msg) from e - - @classmethod - def _load_v1(cls, databag: MutableMapping): - """Load implementation for pydantic v1.""" - if cls._NEST_UNDER: - return cls.parse_obj(json.loads(databag[cls._NEST_UNDER])) - - try: - data = { - k: json.loads(v) - for k, v in databag.items() - # Don't attempt to parse model-external values - if k in {f.alias for f in cls.__fields__.values()} - } - except json.JSONDecodeError as e: - msg = f"invalid databag contents: expecting json. {databag}" - logger.error(msg) - raise DataValidationError(msg) from e - - try: - return cls.parse_raw(json.dumps(data)) # type: ignore - except pydantic.ValidationError as e: - msg = f"failed to validate databag: {databag}" - logger.debug(msg, exc_info=True) - raise DataValidationError(msg) from e - - def dump(self, databag: Optional[MutableMapping] = None, clear: bool = True): - """Write the contents of this model to Juju databag. - - Args: - databag: The databag to write to. - clear: Whether to clear the databag before writing. - - Returns: - MutableMapping: The databag. - """ - if IS_PYDANTIC_V1: - return self._dump_v1(databag, clear) - if clear and databag: - databag.clear() - - if databag is None: - databag = {} - nest_under = self.model_config.get("_NEST_UNDER") - if nest_under: - databag[nest_under] = self.model_dump_json( - by_alias=True, - # skip keys whose values are default - exclude_defaults=True, - ) - return databag - - dct = self.model_dump(mode="json", by_alias=True, exclude_defaults=True) - databag.update({k: json.dumps(v) for k, v in dct.items()}) - return databag - - def _dump_v1(self, databag: Optional[MutableMapping] = None, clear: bool = True): - """Dump implementation for pydantic v1.""" - if clear and databag: - databag.clear() - - if databag is None: - databag = {} - - if self._NEST_UNDER: - databag[self._NEST_UNDER] = self.json(by_alias=True, exclude_defaults=True) - return databag - - dct = json.loads(self.json(by_alias=True, exclude_defaults=True)) - databag.update({k: json.dumps(v) for k, v in dct.items()}) - - return databag - - -class _Certificate(pydantic.BaseModel): - """Certificate model.""" - - ca: str - certificate_signing_request: str - certificate: str - chain: Optional[List[str]] = None - revoked: Optional[bool] = None - - def to_provider_certificate(self, relation_id: int) -> "ProviderCertificate": - """Convert to a ProviderCertificate.""" - return ProviderCertificate( - relation_id=relation_id, - certificate=Certificate.from_string(self.certificate), - certificate_signing_request=CertificateSigningRequest.from_string( - self.certificate_signing_request - ), - ca=Certificate.from_string(self.ca), - chain=[Certificate.from_string(certificate) for certificate in self.chain] - if self.chain - else [], - revoked=self.revoked, - ) - - -class _CertificateSigningRequest(pydantic.BaseModel): - """Certificate signing request model.""" - - certificate_signing_request: str - ca: Optional[bool] - - -class _ProviderApplicationData(_DatabagModel): - """Provider application data model.""" - - certificates: List[_Certificate] = [] - - -class _RequirerData(_DatabagModel): - """Requirer data model. - - The same model is used for the unit and application data. - """ - - certificate_signing_requests: List[_CertificateSigningRequest] = [] - - -class Mode(Enum): - """Enum representing the mode of the certificate request. - - UNIT (default): Request a certificate for the unit. - Each unit will manage its private key, - certificate signing request and certificate. - APP: Request a certificate for the application. - Only the leader unit will manage the private key, certificate signing request - and certificate. - """ - - UNIT = 1 - APP = 2 - - -class PrivateKey: - """This class represents a private key.""" - - def __init__( - self, raw: Optional[str] = None, x509_object: Optional[rsa.RSAPrivateKey] = None - ) -> None: - """Initialize the PrivateKey object. - - If both raw and x509_object are provided, x509_object takes precedence. - """ - if x509_object: - self._private_key = x509_object - elif raw: - self._private_key = serialization.load_pem_private_key( - raw.encode(), - password=None, - ) - else: - raise ValueError("Either raw private key string or x509_object must be provided") - - @property - def raw(self) -> str: - """Return the PEM-formatted string representation of the private key.""" - return str(self) - - def __str__(self): - """Return the private key as a string in PEM format.""" - return ( - self._private_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption(), - ) - .decode() - .strip() - ) - - def __hash__(self): - """Return the hash of the private key.""" - return hash(self.raw) - - @classmethod - def from_string(cls, private_key: str) -> "PrivateKey": - """Create a PrivateKey object from a private key.""" - return cls(raw=private_key) - - def is_valid(self) -> bool: - """Validate that the private key is PEM-formatted, RSA, and at least 2048 bits.""" - try: - if not isinstance(self._private_key, rsa.RSAPrivateKey): - logger.warning("Private key is not an RSA key") - return False - - if self._private_key.key_size < 2048: - logger.warning("RSA key size is less than 2048 bits") - return False - - return True - except ValueError: - logger.warning("Invalid private key format") - return False - - @classmethod - def generate(cls, key_size: int = 2048, public_exponent: int = 65537) -> "PrivateKey": - """Generate a new RSA private key. - - Args: - key_size: The size of the key in bits. - public_exponent: The public exponent of the key. - - Returns: - PrivateKey: The generated private key. - """ - private_key = rsa.generate_private_key( - public_exponent=public_exponent, - key_size=key_size, - ) - _OWASPLogger().log_event( - event="private_key_generated", - level=logging.INFO, - description="Private key generated", - key_size=str(key_size), - ) - return PrivateKey(x509_object=private_key) - - def __eq__(self, other: object) -> bool: - """Check if two PrivateKey objects are equal.""" - if not isinstance(other, PrivateKey): - return NotImplemented - return self.raw == other.raw - - -class Certificate: - """This class represents a certificate.""" - - _cert: x509.Certificate - - def __init__( - self, - raw: Optional[str] = None, # Must remain first argument for backwards compatibility - # Old Interface fields (ignored) - common_name: Optional[str] = None, - expiry_time: Optional[datetime] = None, - validity_start_time: Optional[datetime] = None, - is_ca: Optional[bool] = None, - sans_dns: Optional[Set[str]] = None, - sans_ip: Optional[Set[str]] = None, - sans_oid: Optional[Set[str]] = None, - email_address: Optional[str] = None, - organization: Optional[str] = None, - organizational_unit: Optional[str] = None, - country_name: Optional[str] = None, - state_or_province_name: Optional[str] = None, - locality_name: Optional[str] = None, - # End Old Interface fields - x509_object: Optional[x509.Certificate] = None, - ) -> None: - """Initialize the Certificate object. - - This initializer must maintain the old interface while also allowing - instantiation from an existing x509_object. It ignores all fields - other than raw and x509_object, preferring x509_object. - """ - if x509_object: - self._cert = x509_object - elif raw: - self._cert = x509.load_pem_x509_certificate(data=raw.encode()) - else: - raise ValueError("Either raw certificate string or x509_object must be provided") - - @property - def raw(self) -> str: - """Return the PEM-formatted string representation of the certificate.""" - return str(self) - - @property - def common_name(self) -> str: - """Return the common name of the certificate.""" - # We maintain compatibility with the old interface by returning - # an empty string if no common name is set. - common_name = self._cert.subject.get_attributes_for_oid(NameOID.COMMON_NAME) - return str(common_name[0].value) if common_name else "" - - @property - def expiry_time(self) -> datetime: - """Return the expiry time of the certificate.""" - return self._cert.not_valid_after_utc - - @property - def validity_start_time(self) -> datetime: - """Return the validity start time of the certificate.""" - return self._cert.not_valid_before_utc - - @property - def is_ca(self) -> bool: - """Return whether the certificate is a CA certificate.""" - try: - return self._cert.extensions.get_extension_for_oid( - ExtensionOID.BASIC_CONSTRAINTS - ).value.ca # type: ignore[reportAttributeAccessIssue] - except x509.ExtensionNotFound: - return False - - @property - def sans_dns(self) -> Optional[Set[str]]: - """Return the DNS Subject Alternative Names of the certificate.""" - with suppress(x509.ExtensionNotFound): - sans = self._cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value - return {str(san) for san in sans.get_values_for_type(x509.DNSName)} - return None - - @property - def sans_ip(self) -> Optional[Set[str]]: - """Return the IP Subject Alternative Names of the certificate.""" - with suppress(x509.ExtensionNotFound): - sans = self._cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value - return {str(san) for san in sans.get_values_for_type(x509.IPAddress)} - return None - - @property - def sans_oid(self) -> Optional[Set[str]]: - """Return the OID Subject Alternative Names of the certificate.""" - with suppress(x509.ExtensionNotFound): - sans = self._cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value - return {str(san.dotted_string) for san in sans.get_values_for_type(x509.RegisteredID)} - return None - - @property - def email_address(self) -> Optional[str]: - """Return the email address of the certificate.""" - email_address = self._cert.subject.get_attributes_for_oid(NameOID.EMAIL_ADDRESS) - return str(email_address[0].value) if email_address else None - - @property - def organization(self) -> Optional[str]: - """Return the organization name of the certificate.""" - organization = self._cert.subject.get_attributes_for_oid(NameOID.ORGANIZATION_NAME) - return str(organization[0].value) if organization else None - - @property - def organizational_unit(self) -> Optional[str]: - """Return the organizational unit name of the certificate.""" - organizational_unit = self._cert.subject.get_attributes_for_oid( - NameOID.ORGANIZATIONAL_UNIT_NAME - ) - return str(organizational_unit[0].value) if organizational_unit else None - - @property - def country_name(self) -> Optional[str]: - """Return the country name of the certificate.""" - country_name = self._cert.subject.get_attributes_for_oid(NameOID.COUNTRY_NAME) - return str(country_name[0].value) if country_name else None - - @property - def state_or_province_name(self) -> Optional[str]: - """Return the state or province name of the certificate.""" - state_or_province_name = self._cert.subject.get_attributes_for_oid( - NameOID.STATE_OR_PROVINCE_NAME - ) - return str(state_or_province_name[0].value) if state_or_province_name else None - - @property - def locality_name(self) -> Optional[str]: - """Return the locality name of the certificate.""" - locality_name = self._cert.subject.get_attributes_for_oid(NameOID.LOCALITY_NAME) - return str(locality_name[0].value) if locality_name else None - - def __str__(self) -> str: - """Return the certificate as a string.""" - return self._cert.public_bytes(serialization.Encoding.PEM).decode().strip() - - def __eq__(self, other: object) -> bool: - """Check if two Certificate objects are equal.""" - if not isinstance(other, Certificate): - return NotImplemented - return self.raw == other.raw - - @classmethod - def from_string(cls, certificate: str) -> "Certificate": - """Create a Certificate object from a certificate.""" - try: - certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) - except ValueError as e: - logger.error("Could not load certificate: %s", e) - raise TLSCertificatesError("Could not load certificate") - - return cls(x509_object=certificate_object) - - def matches_private_key(self, private_key: PrivateKey) -> bool: - """Check if this certificate matches a given private key. - - Args: - private_key (PrivateKey): The private key to validate against. - - Returns: - bool: True if the certificate matches the private key, False otherwise. - """ - try: - cert_public_key = self._cert.public_key() - key_public_key = private_key._private_key.public_key() - - if not isinstance(cert_public_key, rsa.RSAPublicKey): - logger.warning("Certificate does not use RSA public key") - return False - - if not isinstance(key_public_key, rsa.RSAPublicKey): - logger.warning("Private key is not an RSA key") - return False - - return cert_public_key.public_numbers() == key_public_key.public_numbers() - except Exception as e: - logger.warning("Failed to validate certificate and private key match: %s", e) - return False - - @classmethod - def generate( - cls, - csr: "CertificateSigningRequest", - ca: "Certificate", - ca_private_key: "PrivateKey", - validity: timedelta, - is_ca: bool = False, - ) -> "Certificate": - """Generate a certificate from a CSR signed by the given CA and CA private key. - - Args: - csr: The certificate signing request. - ca: The CA certificate. - ca_private_key: The CA private key. - validity: The validity period of the certificate. - is_ca: Whether the generated certificate is a CA certificate. - - Returns: - Certificate: The generated certificate. - """ - # Ideally, this would be the constructor, but we can't add new - # required parameters to the constructor without breaking backwards - # compatibility. - private_key = serialization.load_pem_private_key( - str(ca_private_key).encode(), password=None - ) - assert isinstance(private_key, CertificateIssuerPrivateKeyTypes) - - # Create a certificate builder - cert_builder = x509.CertificateBuilder( - subject_name=csr._csr.subject, - # issuer_name=ca._cert.subject, # TODO: Validate this is correct, the old code used `issuer` - issuer_name=ca._cert.issuer, - public_key=csr._csr.public_key(), - serial_number=x509.random_serial_number(), - not_valid_before=datetime.now(timezone.utc), - not_valid_after=datetime.now(timezone.utc) + validity, - ) - extensions = _generate_certificate_request_extensions( - authority_key_identifier=ca._cert.extensions.get_extension_for_class( - x509.SubjectKeyIdentifier - ).value.key_identifier, - csr=csr._csr, - is_ca=is_ca, - ) - for extension in extensions: - try: - cert_builder = cert_builder.add_extension(extension.value, extension.critical) - except ValueError as e: - logger.error("Could not add extension to certificate: %s", e) - raise TLSCertificatesError("Could not add extension to certificate") from e - - # Sign the certificate with the CA's private key - cert = cert_builder.sign(private_key=private_key, algorithm=hashes.SHA256()) - _OWASPLogger().log_event( - event="certificate_generated", - level=logging.INFO, - description="Certificate generated from CSR", - common_name=csr.common_name, - is_ca=str(is_ca), - validity_days=str(validity.days), - ) - - return cls(x509_object=cert) - - @classmethod - def generate_self_signed_ca( - cls, - attributes: "CertificateRequestAttributes", - private_key: PrivateKey, - validity: timedelta, - ) -> "Certificate": - """Generate a self-signed CA certificate. - - Args: - attributes: The certificate request attributes. - private_key: The private key to sign the CA certificate. - validity: The validity period of the CA certificate. - - Returns: - Certificate: The generated CA certificate. - """ - assert isinstance(private_key._private_key, rsa.RSAPrivateKey) - - public_key = private_key._private_key.public_key() - - builder = x509.CertificateBuilder( - public_key=public_key, - serial_number=x509.random_serial_number(), - not_valid_before=datetime.now(timezone.utc), - not_valid_after=datetime.now(timezone.utc) + validity, - ) - - if subject_name := _extract_subject_name_attributes(attributes): - builder = builder.subject_name(subject_name).issuer_name(subject_name) - - builder = ( - builder.add_extension( - x509.SubjectKeyIdentifier.from_public_key(public_key), critical=False - ) - .add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True) - .add_extension( - x509.KeyUsage( - digital_signature=True, - key_encipherment=True, - key_cert_sign=True, - key_agreement=False, - content_commitment=False, - data_encipherment=False, - crl_sign=False, - encipher_only=False, - decipher_only=False, - ), - critical=True, - ) - ) - - if san_extension := _san_extension( - email_address=attributes.email_address, - sans_dns=attributes.sans_dns, - sans_ip=attributes.sans_ip, - sans_oid=attributes.sans_oid, - ): - builder = builder.add_extension(san_extension, critical=False) - - cert = cls(x509_object=builder.sign(private_key._private_key, algorithm=hashes.SHA256())) - - _OWASPLogger().log_event( - event="ca_certificate_generated", - level=logging.INFO, - description="CA certificate generated", - common_name=cert.common_name, - validity_days=str(validity.days), - ) - - return cert - - def __hash__(self): - """Return the hash of the private key.""" - return hash(self.raw) - - -class CertificateSigningRequest: - """A representation of the certificate signing request.""" - - _csr: x509.CertificateSigningRequest - - def __init__( - self, - raw: Optional[str] = None, # Must remain first argument for backwards compatibility - # Old Interface fields (ignored) - common_name: Optional[str] = None, - sans_dns: Optional[Set[str]] = None, - sans_ip: Optional[Set[str]] = None, - sans_oid: Optional[Set[str]] = None, - email_address: Optional[str] = None, - organization: Optional[str] = None, - organizational_unit: Optional[str] = None, - country_name: Optional[str] = None, - state_or_province_name: Optional[str] = None, - locality_name: Optional[str] = None, - has_unique_identifier: Optional[bool] = None, - # End Old Interface fields - x509_object: Optional[x509.CertificateSigningRequest] = None, - ): - """Initialize the CertificateSigningRequest object. - - This initializer must maintain the old interface while also allowing - instantiation from an existing x509_object. It ignores all fields - other than raw and x509_object, preferring x509_object. - """ - if x509_object: - self._csr = x509_object - return - elif raw: - try: - self._csr = x509.load_pem_x509_csr(raw.encode()) - except ValueError as e: - logger.error("Could not load CSR: %s", e) - raise TLSCertificatesError("Could not load CSR") - return - raise ValueError("Either raw CSR string or x509_object must be provided") - - @property - def common_name(self) -> str: - """Return the common name of the CSR.""" - common_name = self._csr.subject.get_attributes_for_oid(NameOID.COMMON_NAME) - return str(common_name[0].value) if common_name else "" - - @property - def sans_dns(self) -> Set[str]: - """Return the DNS Subject Alternative Names of the CSR.""" - with suppress(x509.ExtensionNotFound): - sans = self._csr.extensions.get_extension_for_class(x509.SubjectAlternativeName).value - return {str(san) for san in sans.get_values_for_type(x509.DNSName)} - return set() - - @property - def sans_ip(self) -> Set[str]: - """Return the IP Subject Alternative Names of the CSR.""" - with suppress(x509.ExtensionNotFound): - sans = self._csr.extensions.get_extension_for_class(x509.SubjectAlternativeName).value - return {str(san) for san in sans.get_values_for_type(x509.IPAddress)} - return set() - - @property - def sans_oid(self) -> Set[str]: - """Return the OID Subject Alternative Names of the CSR.""" - with suppress(x509.ExtensionNotFound): - sans = self._csr.extensions.get_extension_for_class(x509.SubjectAlternativeName).value - return {str(san.dotted_string) for san in sans.get_values_for_type(x509.RegisteredID)} - return set() - - @property - def email_address(self) -> Optional[str]: - """Return the email address of the CSR.""" - email_address = self._csr.subject.get_attributes_for_oid(NameOID.EMAIL_ADDRESS) - return str(email_address[0].value) if email_address else None - - @property - def organization(self) -> Optional[str]: - """Return the organization name of the CSR.""" - organization = self._csr.subject.get_attributes_for_oid(NameOID.ORGANIZATION_NAME) - return str(organization[0].value) if organization else None - - @property - def organizational_unit(self) -> Optional[str]: - """Return the organizational unit name of the CSR.""" - organizational_unit = self._csr.subject.get_attributes_for_oid( - NameOID.ORGANIZATIONAL_UNIT_NAME - ) - return str(organizational_unit[0].value) if organizational_unit else None - - @property - def country_name(self) -> Optional[str]: - """Return the country name of the CSR.""" - country_name = self._csr.subject.get_attributes_for_oid(NameOID.COUNTRY_NAME) - return str(country_name[0].value) if country_name else None - - @property - def state_or_province_name(self) -> Optional[str]: - """Return the state or province name of the CSR.""" - state_or_province_name = self._csr.subject.get_attributes_for_oid( - NameOID.STATE_OR_PROVINCE_NAME - ) - return str(state_or_province_name[0].value) if state_or_province_name else None - - @property - def locality_name(self) -> Optional[str]: - """Return the locality name of the CSR.""" - locality_name = self._csr.subject.get_attributes_for_oid(NameOID.LOCALITY_NAME) - return str(locality_name[0].value) if locality_name else None - - @property - def has_unique_identifier(self) -> bool: - """Return whether the CSR has a unique identifier.""" - unique_identifier = self._csr.subject.get_attributes_for_oid( - NameOID.X500_UNIQUE_IDENTIFIER - ) - return bool(unique_identifier) - - @property - def raw(self) -> str: - """Return the PEM-formatted string representation of the CSR.""" - return self.__str__() - - def __str__(self) -> str: - """Return the CSR as a string.""" - return self._csr.public_bytes(serialization.Encoding.PEM).decode().strip() - - @property - def additional_critical_extensions(self) -> List[x509.ExtensionType]: - """Return additional critical extensions present on the CSR (excluding SAN).""" - extensions: List[x509.ExtensionType] = [] - for extension in self._csr.extensions: - if extension.critical and extension.oid != ExtensionOID.SUBJECT_ALTERNATIVE_NAME: - extensions.append(extension.value) - return extensions - - @classmethod - def from_string(cls, csr: str) -> "CertificateSigningRequest": - """Create a CertificateSigningRequest object from a CSR.""" - return cls(raw=csr) - - @classmethod - def from_csr(cls, csr: x509.CertificateSigningRequest) -> "CertificateSigningRequest": - """Create a CertificateSigningRequest object from a CSR.""" - return cls(x509_object=csr) - - def __eq__(self, other: object) -> bool: - """Check if two CertificateSigningRequest objects are equal.""" - if not isinstance(other, CertificateSigningRequest): - return NotImplemented - return self.raw == other.raw - - def __hash__(self): - """Return the hash of the private key.""" - return hash(self.raw) - - def matches_certificate(self, certificate: Certificate) -> bool: - """Check if this CSR matches a given certificate. - - Args: - certificate (Certificate): The certificate to validate against. - - Returns: - bool: True if the CSR matches the certificate, False otherwise. - """ - return self._csr.public_key() == certificate._cert.public_key() - - def matches_private_key(self, key: PrivateKey) -> bool: - """Check if a CSR matches a private key. - - This function only works with RSA keys. - - Args: - key (PrivateKey): Private key - Returns: - bool: True/False depending on whether the CSR matches the private key. - """ - try: - key_object_public_key = key._private_key.public_key() - csr_object_public_key = self._csr.public_key() - if not isinstance(key_object_public_key, rsa.RSAPublicKey): - logger.warning("Key is not an RSA key") - return False - if not isinstance(csr_object_public_key, rsa.RSAPublicKey): - logger.warning("CSR is not an RSA key") - return False - if ( - csr_object_public_key.public_numbers().n - != key_object_public_key.public_numbers().n - ): - logger.warning("Public key numbers between CSR and key do not match") - return False - except ValueError: - logger.warning("Could not load certificate or CSR.") - return False - return True - - def get_sha256_hex(self) -> str: - """Calculate the hash of the provided data and return the hexadecimal representation.""" - digest = hashes.Hash(hashes.SHA256()) - digest.update(self.raw.encode()) - return digest.finalize().hex() - - def sign( - self, ca: Certificate, ca_private_key: PrivateKey, validity: timedelta, is_ca: bool = False - ) -> Certificate: - """Sign this CSR with the given CA and CA private key. - - Args: - ca: The CA certificate. - ca_private_key: The CA private key. - validity: The validity period of the certificate. - is_ca: Whether the generated certificate is a CA certificate. - - Returns: - Certificate: The signed certificate. - """ - return Certificate.generate( - csr=self, - ca=ca, - ca_private_key=ca_private_key, - validity=validity, - is_ca=is_ca, - ) - - @classmethod - def generate( - cls, - attributes: "CertificateRequestAttributes", - private_key: PrivateKey, - ) -> "CertificateSigningRequest": - """Generate a CSR using the supplied attributes and private key. - - Args: - attributes (CertificateRequestAttributes): Certificate request attributes - private_key (PrivateKey): Private key - Returns: - CertificateSigningRequest: CSR - """ - signing_key = private_key._private_key - assert isinstance(signing_key, CertificateIssuerPrivateKeyTypes) - - csr_builder = x509.CertificateSigningRequestBuilder() - if subject_name := _extract_subject_name_attributes(attributes): - csr_builder = csr_builder.subject_name(subject_name) - - _sans: List[x509.GeneralName] = [] - if attributes.sans_oid: - _sans.extend( - [x509.RegisteredID(x509.ObjectIdentifier(san)) for san in attributes.sans_oid] - ) - if attributes.sans_ip: - _sans.extend([x509.IPAddress(ipaddress.ip_address(san)) for san in attributes.sans_ip]) - if attributes.sans_dns: - _sans.extend([x509.DNSName(san) for san in attributes.sans_dns]) - if _sans: - csr_builder = csr_builder.add_extension( - x509.SubjectAlternativeName(set(_sans)), critical=False - ) - if attributes.additional_critical_extensions: - for extension in attributes.additional_critical_extensions: - csr_builder = csr_builder.add_extension(extension, critical=True) - signed_certificate_request = csr_builder.sign(signing_key, hashes.SHA256()) - return cls(x509_object=signed_certificate_request) - - -class CertificateRequestAttributes: - """A representation of the certificate request attributes.""" - - def __init__( - self, - common_name: Optional[str] = None, - sans_dns: Optional[Collection[str]] = None, - sans_ip: Optional[Collection[str]] = None, - sans_oid: Optional[Collection[str]] = None, - email_address: Optional[str] = None, - organization: Optional[str] = None, - organizational_unit: Optional[str] = None, - country_name: Optional[str] = None, - state_or_province_name: Optional[str] = None, - locality_name: Optional[str] = None, - is_ca: bool = False, - add_unique_id_to_subject_name: bool = True, - additional_critical_extensions: Optional[Collection[x509.ExtensionType]] = None, - ): - if not common_name and not sans_dns and not sans_ip and not sans_oid: - raise ValueError( - "At least one of common_name, sans_dns, sans_ip, or sans_oid must be provided" - ) - self._common_name = common_name - self._sans_dns = set(sans_dns) if sans_dns else None - self._sans_ip = set(sans_ip) if sans_ip else None - self._sans_oid = set(sans_oid) if sans_oid else None - self._email_address = email_address - self._organization = organization - self._organizational_unit = organizational_unit - self._country_name = country_name - self._state_or_province_name = state_or_province_name - self._locality_name = locality_name - self._is_ca = is_ca - self._add_unique_id_to_subject_name = add_unique_id_to_subject_name - self._additional_critical_extensions = list(additional_critical_extensions or []) - - @property - def common_name(self) -> str: - """Return the common name.""" - # For legacy interface compatibility, return empty string if not set - return self._common_name if self._common_name else "" - - @property - def sans_dns(self) -> Optional[Set[str]]: - """Return the DNS Subject Alternative Names.""" - return self._sans_dns - - @property - def sans_ip(self) -> Optional[Set[str]]: - """Return the IP Subject Alternative Names.""" - return self._sans_ip - - @property - def sans_oid(self) -> Optional[Set[str]]: - """Return the OID Subject Alternative Names.""" - return self._sans_oid - - @property - def email_address(self) -> Optional[str]: - """Return the email address.""" - return self._email_address - - @property - def organization(self) -> Optional[str]: - """Return the organization name.""" - return self._organization - - @property - def organizational_unit(self) -> Optional[str]: - """Return the organizational unit name.""" - return self._organizational_unit - - @property - def country_name(self) -> Optional[str]: - """Return the country name.""" - return self._country_name - - @property - def state_or_province_name(self) -> Optional[str]: - """Return the state or province name.""" - return self._state_or_province_name - - @property - def locality_name(self) -> Optional[str]: - """Return the locality name.""" - return self._locality_name - - @property - def is_ca(self) -> bool: - """Return whether the certificate is a CA certificate.""" - return self._is_ca - - @property - def add_unique_id_to_subject_name(self) -> bool: - """Return whether to add a unique identifier to the subject name.""" - return self._add_unique_id_to_subject_name - - @property - def additional_critical_extensions(self) -> List[x509.ExtensionType]: - """Return additional critical extensions to be added to the CSR.""" - return self._additional_critical_extensions - - @classmethod - def from_csr( - cls, csr: CertificateSigningRequest, is_ca: bool - ) -> "CertificateRequestAttributes": - """Create CertificateRequestAttributes from a CertificateSigningRequest. - - Args: - csr: The CSR to extract attributes from. - is_ca: Whether a CA certificate is being requested. - - Returns: - CertificateRequestAttributes: The extracted attributes. - """ - return cls( - common_name=csr.common_name, - sans_dns=csr.sans_dns, - sans_ip=csr.sans_ip, - sans_oid=csr.sans_oid, - email_address=csr.email_address, - organization=csr.organization, - organizational_unit=csr.organizational_unit, - country_name=csr.country_name, - state_or_province_name=csr.state_or_province_name, - locality_name=csr.locality_name, - is_ca=is_ca, - add_unique_id_to_subject_name=csr.has_unique_identifier, - additional_critical_extensions=csr.additional_critical_extensions, - ) - - def __eq__(self, other: object) -> bool: - """Check if two CertificateRequestAttributes objects are equal.""" - if not isinstance(other, CertificateRequestAttributes): - return NotImplemented - return ( - self.common_name == other.common_name - and self.sans_dns == other.sans_dns - and self.sans_ip == other.sans_ip - and self.sans_oid == other.sans_oid - and self.email_address == other.email_address - and self.organization == other.organization - and self.organizational_unit == other.organizational_unit - and self.country_name == other.country_name - and self.state_or_province_name == other.state_or_province_name - and self.locality_name == other.locality_name - and self.is_ca == other.is_ca - and self.add_unique_id_to_subject_name == other.add_unique_id_to_subject_name - and self.additional_critical_extensions == other.additional_critical_extensions - ) - - def is_valid(self) -> bool: - """Validate the attributes of the certificate request. - - Returns: - bool: True if the attributes are valid, False otherwise. - """ - if not self.common_name and not self.sans_dns and not self.sans_ip and not self.sans_oid: - logger.warning( - "At least one of common_name, sans_dns, sans_ip, or sans_oid must be provided" - ) - return False - return True - - def generate_csr( - self, - private_key: PrivateKey, - ) -> CertificateSigningRequest: - """Generate a CSR using the current attributes and a private key. - - Args: - private_key (PrivateKey): Private key to sign the CSR. - - Returns: - CertificateSigningRequest: The generated CSR. - """ - return CertificateSigningRequest.generate(self, private_key) - - -@dataclass(frozen=True) -class ProviderCertificate: - """This class represents a certificate provided by the TLS provider.""" - - relation_id: int - certificate: Certificate - certificate_signing_request: CertificateSigningRequest - ca: Certificate - chain: List[Certificate] - revoked: Optional[bool] = None - - def to_json(self) -> str: - """Return the object as a JSON string. - - Returns: - str: JSON representation of the object - """ - return json.dumps( - { - "csr": str(self.certificate_signing_request), - "certificate": str(self.certificate), - "ca": str(self.ca), - "chain": [str(cert) for cert in self.chain], - "revoked": self.revoked, - } - ) - - -@dataclass(frozen=True) -class RequirerCertificateRequest: - """This class represents a certificate signing request requested by a specific TLS requirer.""" - - relation_id: int - certificate_signing_request: CertificateSigningRequest - is_ca: bool - - -class CertificateAvailableEvent(EventBase): - """Charm Event triggered when a TLS certificate is available.""" - - def __init__( - self, - handle: Handle, - certificate: Certificate, - certificate_signing_request: CertificateSigningRequest, - ca: Certificate, - chain: List[Certificate], - ): - super().__init__(handle) - self.certificate = certificate - self.certificate_signing_request = certificate_signing_request - self.ca = ca - self.chain = chain - - def snapshot(self) -> dict: - """Return snapshot.""" - return { - "certificate": str(self.certificate), - "certificate_signing_request": str(self.certificate_signing_request), - "ca": str(self.ca), - "chain": json.dumps([str(certificate) for certificate in self.chain]), - } - - def restore(self, snapshot: dict): - """Restore snapshot.""" - self.certificate = Certificate.from_string(snapshot["certificate"]) - self.certificate_signing_request = CertificateSigningRequest.from_string( - snapshot["certificate_signing_request"] - ) - self.ca = Certificate.from_string(snapshot["ca"]) - chain_strs = json.loads(snapshot["chain"]) - self.chain = [Certificate.from_string(chain_str) for chain_str in chain_strs] - - def chain_as_pem(self) -> str: - """Return full certificate chain as a PEM string.""" - return "\n\n".join([str(cert) for cert in self.chain]) - - -def generate_private_key( - key_size: int = 2048, - public_exponent: int = 65537, -) -> PrivateKey: - """Generate a private key with the RSA algorithm. - - Args: - key_size (int): Key size in bits, must be at least 2048 bits - public_exponent: Public exponent. - - Returns: - PrivateKey: Private Key - """ - warnings.warn( - "generate_private_key() is deprecated. Use PrivateKey.generate() instead.", - DeprecationWarning, - ) - return PrivateKey.generate(key_size=key_size, public_exponent=public_exponent) - - -def calculate_relative_datetime(target_time: datetime, fraction: float) -> datetime: - """Calculate a datetime that is a given percentage from now to a target time. - - Args: - target_time (datetime): The future datetime to interpolate towards. - fraction (float): Fraction of the interval from now to target_time (0.0-1.0). - 1.0 means return target_time, - 0.9 means return the time after 90% of the interval has passed, - and 0.0 means return now. - """ - if fraction <= 0.0 or fraction > 1.0: - raise ValueError("Invalid fraction. Must be between 0.0 and 1.0") - now = datetime.now(timezone.utc) - time_until_target = target_time - now - return now + time_until_target * fraction - - -def chain_has_valid_order(chain: List[str]) -> bool: - """Check if the chain has a valid order. - - Validates that each certificate in the chain is properly signed by the next certificate. - The chain should be ordered from leaf to root, where each certificate is signed by - the next one in the chain. - - Args: - chain (List[str]): List of certificates in PEM format, ordered from leaf to root - - Returns: - bool: True if the chain has a valid order, False otherwise. - """ - if len(chain) < 2: - return True - - try: - for i in range(len(chain) - 1): - cert = x509.load_pem_x509_certificate(chain[i].encode()) - issuer = x509.load_pem_x509_certificate(chain[i + 1].encode()) - cert.verify_directly_issued_by(issuer) - return True - except (ValueError, TypeError, InvalidSignature): - return False - - -def generate_csr( # noqa: C901 - private_key: PrivateKey, - common_name: str, - sans_dns: Optional[FrozenSet[str]] = frozenset(), - sans_ip: Optional[FrozenSet[str]] = frozenset(), - sans_oid: Optional[FrozenSet[str]] = frozenset(), - organization: Optional[str] = None, - organizational_unit: Optional[str] = None, - email_address: Optional[str] = None, - country_name: Optional[str] = None, - locality_name: Optional[str] = None, - state_or_province_name: Optional[str] = None, - add_unique_id_to_subject_name: bool = True, -) -> CertificateSigningRequest: - """Generate a CSR using private key and subject. - - Args: - private_key (PrivateKey): Private key - common_name (str): Common name - sans_dns (FrozenSet[str]): DNS Subject Alternative Names - sans_ip (FrozenSet[str]): IP Subject Alternative Names - sans_oid (FrozenSet[str]): OID Subject Alternative Names - organization (Optional[str]): Organization name - organizational_unit (Optional[str]): Organizational unit name - email_address (Optional[str]): Email address - country_name (Optional[str]): Country name - state_or_province_name (Optional[str]): State or province name - locality_name (Optional[str]): Locality name - add_unique_id_to_subject_name (bool): Whether a unique ID must be added to the CSR's - subject name. Always leave to "True" when the CSR is used to request certificates - using the tls-certificates relation. - - Returns: - CertificateSigningRequest: CSR - """ - warnings.warn( - "generate_csr() is deprecated. Use CertificateRequestAttributes.generate_csr() or CertificateSigningRequest.generate() instead.", - DeprecationWarning, - ) - return CertificateRequestAttributes( - common_name=common_name, - sans_dns=sans_dns, - sans_ip=sans_ip, - sans_oid=sans_oid, - organization=organization, - organizational_unit=organizational_unit, - email_address=email_address, - country_name=country_name, - state_or_province_name=state_or_province_name, - locality_name=locality_name, - add_unique_id_to_subject_name=add_unique_id_to_subject_name, - ).generate_csr(private_key=private_key) - - -def generate_ca( - private_key: PrivateKey, - validity: timedelta, - common_name: str, - sans_dns: Optional[FrozenSet[str]] = frozenset(), - sans_ip: Optional[FrozenSet[str]] = frozenset(), - sans_oid: Optional[FrozenSet[str]] = frozenset(), - organization: Optional[str] = None, - organizational_unit: Optional[str] = None, - email_address: Optional[str] = None, - country_name: Optional[str] = None, - state_or_province_name: Optional[str] = None, - locality_name: Optional[str] = None, -) -> Certificate: - """Generate a self signed CA Certificate. - - Args: - private_key: Private key - validity: Certificate validity time - common_name: Common Name that can be an IP or a Full Qualified Domain Name (FQDN). - sans_dns: DNS Subject Alternative Names - sans_ip: IP Subject Alternative Names - sans_oid: OID Subject Alternative Names - organization: Organization name - organizational_unit: Organizational unit name - email_address: Email address - country_name: Certificate Issuing country - state_or_province_name: Certificate Issuing state or province - locality_name: Certificate Issuing locality - - Returns: - CA Certificate. - """ - warnings.warn( - "generate_ca() is deprecated. Use Certificate.generate_self_signed_ca() instead.", - DeprecationWarning, - ) - attributes = CertificateRequestAttributes( - common_name=common_name, - sans_dns=sans_dns, - sans_ip=sans_ip, - sans_oid=sans_oid, - organization=organization, - organizational_unit=organizational_unit, - email_address=email_address, - country_name=country_name, - state_or_province_name=state_or_province_name, - locality_name=locality_name, - is_ca=True, - ) - return Certificate.generate_self_signed_ca(attributes, private_key, validity) - - -def _san_extension( - email_address: Optional[str] = None, - sans_dns: Optional[Collection[str]] = frozenset(), - sans_ip: Optional[Collection[str]] = frozenset(), - sans_oid: Optional[Collection[str]] = frozenset(), -) -> Optional[x509.SubjectAlternativeName]: - sans: List[x509.GeneralName] = [] - if email_address: - # If an e-mail address was provided, it should always be in the SAN - sans.append(x509.RFC822Name(email_address)) - if sans_dns: - sans.extend([x509.DNSName(san) for san in sans_dns]) - if sans_ip: - sans.extend([x509.IPAddress(ipaddress.ip_address(san)) for san in sans_ip]) - if sans_oid: - sans.extend([x509.RegisteredID(x509.ObjectIdentifier(san)) for san in sans_oid]) - if not sans: - return None - return x509.SubjectAlternativeName(sans) - - -def generate_certificate( - csr: CertificateSigningRequest, - ca: Certificate, - ca_private_key: PrivateKey, - validity: timedelta, - is_ca: bool = False, -) -> Certificate: - """Generate a TLS certificate based on a CSR. - - Args: - csr (CertificateSigningRequest): CSR - ca (Certificate): CA Certificate - ca_private_key (PrivateKey): CA private key - validity (timedelta): Certificate validity time - is_ca (bool): Whether the certificate is a CA certificate - - Returns: - Certificate: Certificate - """ - warnings.warn( - "generate_certificate() is deprecated. Use Certificate.generate() instead.", - DeprecationWarning, - ) - return Certificate.generate( - csr=csr, - ca=ca, - ca_private_key=ca_private_key, - validity=validity, - is_ca=is_ca, - ) - - -def _extract_subject_name_attributes( - attributes: CertificateRequestAttributes, -) -> Optional[x509.Name]: - subject_name_attributes = [] - if attributes.common_name: - subject_name_attributes.append( - x509.NameAttribute(x509.NameOID.COMMON_NAME, attributes.common_name) - ) - if attributes.add_unique_id_to_subject_name: - unique_identifier = uuid.uuid4() - subject_name_attributes.append( - x509.NameAttribute(x509.NameOID.X500_UNIQUE_IDENTIFIER, str(unique_identifier)) - ) - if attributes.organization: - subject_name_attributes.append( - x509.NameAttribute(x509.NameOID.ORGANIZATION_NAME, attributes.organization) - ) - if attributes.organizational_unit: - subject_name_attributes.append( - x509.NameAttribute( - x509.NameOID.ORGANIZATIONAL_UNIT_NAME, - attributes.organizational_unit, - ) - ) - if attributes.email_address: - subject_name_attributes.append( - x509.NameAttribute(x509.NameOID.EMAIL_ADDRESS, attributes.email_address) - ) - if attributes.country_name: - subject_name_attributes.append( - x509.NameAttribute(x509.NameOID.COUNTRY_NAME, attributes.country_name) - ) - if attributes.state_or_province_name: - subject_name_attributes.append( - x509.NameAttribute( - x509.NameOID.STATE_OR_PROVINCE_NAME, - attributes.state_or_province_name, - ) - ) - if attributes.locality_name: - subject_name_attributes.append( - x509.NameAttribute(x509.NameOID.LOCALITY_NAME, attributes.locality_name) - ) - - if subject_name_attributes: - return x509.Name(subject_name_attributes) - - return None - - -def _generate_certificate_request_extensions( - authority_key_identifier: bytes, - csr: x509.CertificateSigningRequest, - is_ca: bool, -) -> List[x509.Extension]: - """Generate a list of certificate extensions from a CSR and other known information. - - Args: - authority_key_identifier (bytes): Authority key identifier - csr (x509.CertificateSigningRequest): CSR - is_ca (bool): Whether the certificate is a CA certificate - - Returns: - List[x509.Extension]: List of extensions - """ - cert_extensions_list: List[x509.Extension] = [ - x509.Extension( - oid=ExtensionOID.AUTHORITY_KEY_IDENTIFIER, - value=x509.AuthorityKeyIdentifier( - key_identifier=authority_key_identifier, - authority_cert_issuer=None, - authority_cert_serial_number=None, - ), - critical=False, - ), - x509.Extension( - oid=ExtensionOID.SUBJECT_KEY_IDENTIFIER, - value=x509.SubjectKeyIdentifier.from_public_key(csr.public_key()), - critical=False, - ), - x509.Extension( - oid=ExtensionOID.BASIC_CONSTRAINTS, - critical=True, - value=x509.BasicConstraints(ca=is_ca, path_length=None), - ), - ] - if sans := _generate_subject_alternative_name_extension(csr): - cert_extensions_list.append(sans) - - if is_ca: - cert_extensions_list.append( - x509.Extension( - ExtensionOID.KEY_USAGE, - critical=True, - value=x509.KeyUsage( - digital_signature=False, - content_commitment=False, - key_encipherment=False, - data_encipherment=False, - key_agreement=False, - key_cert_sign=True, - crl_sign=True, - encipher_only=False, - decipher_only=False, - ), - ) - ) - - existing_oids = {ext.oid for ext in cert_extensions_list} - for extension in csr.extensions: - if extension.oid == ExtensionOID.SUBJECT_ALTERNATIVE_NAME: - continue - if extension.oid in existing_oids: - logger.warning("Extension %s is managed by the TLS provider, ignoring.", extension.oid) - continue - cert_extensions_list.append(extension) - - return cert_extensions_list - - -def _generate_subject_alternative_name_extension( - csr: x509.CertificateSigningRequest, -) -> Optional[x509.Extension]: - sans: List[x509.GeneralName] = [] - try: - loaded_san_ext = csr.extensions.get_extension_for_class(x509.SubjectAlternativeName) - sans.extend( - [x509.DNSName(name) for name in loaded_san_ext.value.get_values_for_type(x509.DNSName)] - ) - sans.extend( - [x509.IPAddress(ip) for ip in loaded_san_ext.value.get_values_for_type(x509.IPAddress)] - ) - sans.extend( - [ - x509.RegisteredID(oid) - for oid in loaded_san_ext.value.get_values_for_type(x509.RegisteredID) - ] - ) - sans.extend( - [ - x509.RFC822Name(name) - for name in loaded_san_ext.value.get_values_for_type(x509.RFC822Name) - ] - ) - except x509.ExtensionNotFound: - pass - # If email is present in the CSR Subject, make sure it is also in the SANS - # to conform to RFC 5280. - email = csr.subject.get_attributes_for_oid(NameOID.EMAIL_ADDRESS) - if email: - email_rfc822 = x509.RFC822Name(str(email[0].value)) - if email_rfc822 not in sans: - sans.append(email_rfc822) - - return ( - x509.Extension( - oid=ExtensionOID.SUBJECT_ALTERNATIVE_NAME, - critical=False, - value=x509.SubjectAlternativeName(sans), - ) - if sans - else None - ) - - -class CertificatesRequirerCharmEvents(CharmEvents): - """List of events that the TLS Certificates requirer charm can leverage.""" - - certificate_available = EventSource(CertificateAvailableEvent) - - -class TLSCertificatesRequiresV4(Object): - """A class to manage the TLS certificates interface for a unit or app.""" - - on = CertificatesRequirerCharmEvents() # type: ignore[reportAssignmentType] - - def __init__( - self, - charm: CharmBase, - relationship_name: str, - certificate_requests: List[CertificateRequestAttributes], - mode: Mode = Mode.UNIT, - refresh_events: List[BoundEvent] = [], - private_key: Optional[PrivateKey] = None, - renewal_relative_time: float = 0.9, - ): - """Create a new instance of the TLSCertificatesRequiresV4 class. - - Args: - charm (CharmBase): The charm instance to relate to. - relationship_name (str): The name of the relation that provides the certificates. - certificate_requests (List[CertificateRequestAttributes]): - A list with the attributes of the certificate requests. - mode (Mode): Whether to use unit or app certificates mode. Default is Mode.UNIT. - In UNIT mode the requirer will place the csr in the unit relation data. - Each unit will manage its private key, - certificate signing request and certificate. - UNIT mode is for use cases where each unit has its own identity. - If you don't know which mode to use, you likely need UNIT. - In APP mode the leader unit will place the csr in the app relation databag. - APP mode is for use cases where the underlying application needs the certificate - for example using it as an intermediate CA to sign other certificates. - The certificate can only be accessed by the leader unit. - refresh_events (List[BoundEvent]): A list of events to trigger a refresh of - the certificates. - private_key (Optional[PrivateKey]): The private key to use for the certificates. - If provided, it will be used instead of generating a new one. - If the key is not valid an exception will be raised. - Using this parameter is discouraged, - having to pass around private keys manually can be a security concern. - Allowing the library to generate and manage the key is the more secure approach. - renewal_relative_time (float): The time to renew the certificate relative to its - expiry. - Default is 0.9, meaning 90% of the validity period. - The minimum value is 0.5, meaning 50% of the validity period. - If an invalid value is provided, an exception will be raised. - """ - super().__init__(charm, relationship_name) - if not JujuVersion.from_environ().has_secrets: - logger.warning("This version of the TLS library requires Juju secrets (Juju >= 3.0)") - if not self._mode_is_valid(mode): - raise TLSCertificatesError("Invalid mode. Must be Mode.UNIT or Mode.APP") - for certificate_request in certificate_requests: - if not certificate_request.is_valid(): - raise TLSCertificatesError("Invalid certificate request") - self.charm = charm - self.relationship_name = relationship_name - self.certificate_requests = certificate_requests - self.mode = mode - if private_key and not private_key.is_valid(): - raise TLSCertificatesError("Invalid private key") - if renewal_relative_time <= 0.5 or renewal_relative_time > 1.0: - raise TLSCertificatesError( - "Invalid renewal relative time. Must be between 0.5 and 1.0" - ) - self._private_key = private_key - self.renewal_relative_time = renewal_relative_time - self.framework.observe(charm.on[relationship_name].relation_created, self._configure) - self.framework.observe(charm.on[relationship_name].relation_changed, self._configure) - self.framework.observe(charm.on.secret_expired, self._on_secret_expired) - self.framework.observe(charm.on.secret_remove, self._on_secret_remove) - for event in refresh_events: - self.framework.observe(event, self._configure) - self._security_logger = _OWASPLogger(application=f"tls-certificates-{charm.app.name}") - - def _configure(self, _: Optional[EventBase] = None): - """Handle TLS Certificates Relation Data. - - This method is called during any TLS relation event. - It will generate a private key if it doesn't exist yet. - It will send certificate requests if they haven't been sent yet. - It will find available certificates and emit events. - """ - if not self._tls_relation_created(): - logger.debug("TLS relation not created yet.") - return - self._ensure_private_key() - self._cleanup_certificate_requests() - self._send_certificate_requests() - self._find_available_certificates() - - def _mode_is_valid(self, mode: Mode) -> bool: - return mode in [Mode.UNIT, Mode.APP] - - def _validate_secret_exists(self, secret: Secret) -> None: - secret.get_info() # Will raise `SecretNotFoundError` if the secret does not exist - - def _on_secret_remove(self, event: SecretRemoveEvent) -> None: - """Handle Secret Removed Event.""" - try: - # Ensure the secret exists before trying to remove it, otherwise - # the unit could be stuck in an error state. See the docstring of - # `remove_revision` and the below issue for more information. - # https://github.com/juju/juju/issues/19036 - self._validate_secret_exists(event.secret) - event.secret.remove_revision(event.revision) - except SecretNotFoundError: - logger.warning( - "No such secret %s, nothing to remove", - event.secret.label or event.secret.id, - ) - return - - def _on_secret_expired(self, event: SecretExpiredEvent) -> None: - """Handle Secret Expired Event. - - Renews certificate requests and removes the expired secret. - """ - if not event.secret.label or not event.secret.label.startswith(f"{LIBID}-certificate"): - return - try: - csr_str = event.secret.get_content(refresh=True)["csr"] - except ModelError: - logger.error("Failed to get CSR from secret - Skipping") - return - csr = CertificateSigningRequest.from_string(csr_str) - self._renew_certificate_request(csr) - event.secret.remove_all_revisions() - - def sync(self) -> None: - """Sync TLS Certificates Relation Data. - - This method allows the requirer to sync the TLS certificates relation data - without waiting for the refresh events to be triggered. - """ - self._configure() - - def renew_certificate(self, certificate: ProviderCertificate) -> None: - """Request the renewal of the provided certificate.""" - certificate_signing_request = certificate.certificate_signing_request - secret_label = self._get_csr_secret_label(certificate_signing_request) - try: - secret = self.model.get_secret(label=secret_label) - except SecretNotFoundError: - logger.warning("No matching secret found - Skipping renewal") - return - current_csr = secret.get_content(refresh=True).get("csr", "") - if current_csr != str(certificate_signing_request): - logger.warning("No matching CSR found - Skipping renewal") - return - self._renew_certificate_request(certificate_signing_request) - secret.remove_all_revisions() - - def _renew_certificate_request(self, csr: CertificateSigningRequest): - """Remove existing CSR from relation data and create a new one.""" - self._remove_requirer_csr_from_relation_data(csr) - self._send_certificate_requests() - logger.info("Renewed certificate request") - - def _remove_requirer_csr_from_relation_data(self, csr: CertificateSigningRequest) -> None: - relation = self.model.get_relation(self.relationship_name) - if not relation: - logger.debug("No relation: %s", self.relationship_name) - return - if not self.get_csrs_from_requirer_relation_data(): - logger.info("No CSRs in relation data - Doing nothing") - return - app_or_unit = self._get_app_or_unit() - try: - requirer_relation_data = _RequirerData.load(relation.data[app_or_unit]) - except DataValidationError: - logger.warning("Invalid relation data - Skipping removal of CSR") - return - new_relation_data = copy.deepcopy(requirer_relation_data.certificate_signing_requests) - for requirer_csr in new_relation_data: - if requirer_csr.certificate_signing_request.strip() == str(csr).strip(): - new_relation_data.remove(requirer_csr) - try: - _RequirerData(certificate_signing_requests=new_relation_data).dump( - relation.data[app_or_unit] - ) - logger.info("Removed CSR from relation data") - except ModelError: - logger.warning("Failed to update relation data") - - def _get_app_or_unit(self) -> Union[Application, Unit]: - """Return the unit or app object based on the mode.""" - if self.mode == Mode.UNIT: - return self.model.unit - elif self.mode == Mode.APP: - return self.model.app - raise TLSCertificatesError("Invalid mode") - - @property - def private_key(self) -> Optional[PrivateKey]: - """Return the private key.""" - if self._private_key: - return self._private_key - if not self._private_key_generated(): - return None - secret = self.charm.model.get_secret(label=self._get_private_key_secret_label()) - private_key = secret.get_content(refresh=True)["private-key"] - return PrivateKey.from_string(private_key) - - def _ensure_private_key(self) -> None: - """Make sure there is a private key to be used. - - It will make sure there is a private key passed by the charm using the private_key - parameter or generate a new one otherwise. - """ - # Remove the generated private key - # if one has been passed by the charm using the private_key parameter - if self._private_key: - self._remove_private_key_secret() - return - if self._private_key_generated(): - logger.debug("Private key already generated") - return - self._generate_private_key() - - def regenerate_private_key(self) -> None: - """Regenerate the private key. - - Generate a new private key, remove old certificate requests and send new ones. - - Raises: - TLSCertificatesError: If the private key is passed by the charm using the - private_key parameter. - """ - if self._private_key: - raise TLSCertificatesError( - "Private key is passed by the charm through the private_key parameter, " - "this function can't be used" - ) - if not self._private_key_generated(): - logger.warning("No private key to regenerate") - return - self._generate_private_key() - self._cleanup_certificate_requests() - self._send_certificate_requests() - - def _generate_private_key(self) -> None: - """Generate a new private key and store it in a secret. - - This is the case when the private key used is generated by the library. - and not passed by the charm using the private_key parameter. - """ - self._store_private_key_in_secret(generate_private_key()) - logger.info("Private key generated") - - def _private_key_generated(self) -> bool: - """Check if a private key is stored in a secret. - - This is the case when the private key used is generated by the library. - This should not exist when the private key used - is passed by the charm using the private_key parameter. - """ - try: - secret = self.charm.model.get_secret(label=self._get_private_key_secret_label()) - secret.get_content(refresh=True) - return True - except SecretNotFoundError: - return False - - def _store_private_key_in_secret(self, private_key: PrivateKey) -> None: - try: - secret = self.charm.model.get_secret(label=self._get_private_key_secret_label()) - secret.set_content({"private-key": str(private_key)}) - secret.get_content(refresh=True) - except SecretNotFoundError: - self.charm.unit.add_secret( - content={"private-key": str(private_key)}, - label=self._get_private_key_secret_label(), - ) - - def _remove_private_key_secret(self) -> None: - """Remove the private key secret.""" - try: - secret = self.charm.model.get_secret(label=self._get_private_key_secret_label()) - secret.remove_all_revisions() - except SecretNotFoundError: - logger.warning("Private key secret not found, nothing to remove") - - def _csr_matches_certificate_request( - self, certificate_signing_request: CertificateSigningRequest, is_ca: bool - ) -> bool: - for certificate_request in self.certificate_requests: - if certificate_request == CertificateRequestAttributes.from_csr( - certificate_signing_request, - is_ca, - ): - return True - return False - - def _certificate_requested(self, certificate_request: CertificateRequestAttributes) -> bool: - if not self.private_key: - return False - csr = self._certificate_requested_for_attributes(certificate_request) - if not csr: - return False - if not csr.certificate_signing_request.matches_private_key(key=self.private_key): - return False - return True - - def _certificate_requested_for_attributes( - self, - certificate_request: CertificateRequestAttributes, - ) -> Optional[RequirerCertificateRequest]: - for requirer_csr in self.get_csrs_from_requirer_relation_data(): - if certificate_request == CertificateRequestAttributes.from_csr( - requirer_csr.certificate_signing_request, - requirer_csr.is_ca, - ): - return requirer_csr - return None - - def get_csrs_from_requirer_relation_data(self) -> List[RequirerCertificateRequest]: - """Return list of requirer's CSRs from relation data.""" - if self.mode == Mode.APP and not self.model.unit.is_leader(): - logger.debug("Not a leader unit - Skipping") - return [] - relation = self.model.get_relation(self.relationship_name) - if not relation: - logger.debug("No relation: %s", self.relationship_name) - return [] - app_or_unit = self._get_app_or_unit() - try: - requirer_relation_data = _RequirerData.load(relation.data[app_or_unit]) - except DataValidationError: - logger.warning("Invalid relation data") - return [] - requirer_csrs = [] - for csr in requirer_relation_data.certificate_signing_requests: - requirer_csrs.append( - RequirerCertificateRequest( - relation_id=relation.id, - certificate_signing_request=CertificateSigningRequest.from_string( - csr.certificate_signing_request - ), - is_ca=csr.ca if csr.ca else False, - ) - ) - return requirer_csrs - - def get_provider_certificates(self) -> List[ProviderCertificate]: - """Return list of certificates from the provider's relation data.""" - return self._load_provider_certificates() - - def _load_provider_certificates(self) -> List[ProviderCertificate]: - relation = self.model.get_relation(self.relationship_name) - if not relation: - logger.debug("No relation: %s", self.relationship_name) - return [] - if not relation.app: - logger.debug("No remote app in relation: %s", self.relationship_name) - return [] - try: - provider_relation_data = _ProviderApplicationData.load(relation.data[relation.app]) - except DataValidationError: - logger.warning("Invalid relation data") - return [] - return [ - certificate.to_provider_certificate(relation_id=relation.id) - for certificate in provider_relation_data.certificates - ] - - def _request_certificate(self, csr: CertificateSigningRequest, is_ca: bool) -> None: - """Add CSR to relation data.""" - if self.mode == Mode.APP and not self.model.unit.is_leader(): - logger.debug("Not a leader unit - Skipping") - return - relation = self.model.get_relation(self.relationship_name) - if not relation: - logger.debug("No relation: %s", self.relationship_name) - return - new_csr = _CertificateSigningRequest( - certificate_signing_request=str(csr).strip(), ca=is_ca - ) - app_or_unit = self._get_app_or_unit() - try: - requirer_relation_data = _RequirerData.load(relation.data[app_or_unit]) - except DataValidationError: - requirer_relation_data = _RequirerData( - certificate_signing_requests=[], - ) - new_relation_data = copy.deepcopy(requirer_relation_data.certificate_signing_requests) - new_relation_data.append(new_csr) - try: - _RequirerData(certificate_signing_requests=new_relation_data).dump( - relation.data[app_or_unit] - ) - logger.info("Certificate signing request added to relation data.") - except ModelError: - logger.warning("Failed to update relation data") - - def _send_certificate_requests(self): - if not self.private_key: - logger.debug("Private key not generated yet.") - return - for certificate_request in self.certificate_requests: - if not self._certificate_requested(certificate_request): - csr = certificate_request.generate_csr( - private_key=self.private_key, - ) - if not csr: - logger.warning("Failed to generate CSR") - continue - self._request_certificate(csr=csr, is_ca=certificate_request.is_ca) - - def get_assigned_certificate( - self, certificate_request: CertificateRequestAttributes - ) -> Tuple[Optional[ProviderCertificate], Optional[PrivateKey]]: - """Get the certificate that was assigned to the given certificate request.""" - for requirer_csr in self.get_csrs_from_requirer_relation_data(): - if certificate_request == CertificateRequestAttributes.from_csr( - requirer_csr.certificate_signing_request, - requirer_csr.is_ca, - ): - return self._find_certificate_in_relation_data(requirer_csr), self.private_key - return None, None - - def get_assigned_certificates( - self, - ) -> Tuple[List[ProviderCertificate], Optional[PrivateKey]]: - """Get a list of certificates that were assigned to this or app.""" - assigned_certificates = [] - for requirer_csr in self.get_csrs_from_requirer_relation_data(): - if cert := self._find_certificate_in_relation_data(requirer_csr): - assigned_certificates.append(cert) - return assigned_certificates, self.private_key - - def _find_certificate_in_relation_data( - self, csr: RequirerCertificateRequest - ) -> Optional[ProviderCertificate]: - """Return the certificate that matches the given CSR, validated against the private key.""" - if not self.private_key: - return None - for provider_certificate in self.get_provider_certificates(): - if provider_certificate.certificate_signing_request == csr.certificate_signing_request: - if provider_certificate.certificate.is_ca and not csr.is_ca: - logger.warning("Non CA certificate requested, got a CA certificate, ignoring") - continue - elif not provider_certificate.certificate.is_ca and csr.is_ca: - logger.warning("CA certificate requested, got a non CA certificate, ignoring") - continue - if not provider_certificate.certificate.matches_private_key(self.private_key): - logger.warning( - "Certificate does not match the private key. Ignoring invalid certificate." - ) - continue - return provider_certificate - return None - - def _find_available_certificates(self): - """Find available certificates and emit events. - - This method will find certificates that are available for the requirer's CSRs. - If a certificate is found, it will be set as a secret and an event will be emitted. - If a certificate is revoked, the secret will be removed and an event will be emitted. - """ - requirer_csrs = self.get_csrs_from_requirer_relation_data() - csrs = [csr.certificate_signing_request for csr in requirer_csrs] - provider_certificates = self.get_provider_certificates() - for provider_certificate in provider_certificates: - if provider_certificate.certificate_signing_request in csrs: - secret_label = self._get_csr_secret_label( - provider_certificate.certificate_signing_request - ) - if provider_certificate.revoked: - with suppress(SecretNotFoundError): - logger.debug( - "Removing secret with label %s", - secret_label, - ) - secret = self.model.get_secret(label=secret_label) - secret.remove_all_revisions() - else: - if not self._csr_matches_certificate_request( - certificate_signing_request=provider_certificate.certificate_signing_request, - is_ca=provider_certificate.certificate.is_ca, - ): - logger.debug("Certificate requested for different attributes - Skipping") - continue - try: - secret = self.model.get_secret(label=secret_label) - logger.debug("Setting secret with label %s", secret_label) - # Juju < 3.6 will create a new revision even if the content is the same - if secret.get_content(refresh=True).get("certificate", "") == str( - provider_certificate.certificate - ): - logger.debug( - "Secret %s with correct certificate already exists", secret_label - ) - continue - secret.set_content( - content={ - "certificate": str(provider_certificate.certificate), - "csr": str(provider_certificate.certificate_signing_request), - } - ) - secret.set_info( - expire=calculate_relative_datetime( - target_time=provider_certificate.certificate.expiry_time, - fraction=self.renewal_relative_time, - ), - ) - secret.get_content(refresh=True) - except SecretNotFoundError: - logger.debug("Creating new secret with label %s", secret_label) - secret = self.charm.unit.add_secret( - content={ - "certificate": str(provider_certificate.certificate), - "csr": str(provider_certificate.certificate_signing_request), - }, - label=secret_label, - expire=calculate_relative_datetime( - target_time=provider_certificate.certificate.expiry_time, - fraction=self.renewal_relative_time, - ), - ) - self.on.certificate_available.emit( - certificate_signing_request=provider_certificate.certificate_signing_request, - certificate=provider_certificate.certificate, - ca=provider_certificate.ca, - chain=provider_certificate.chain, - ) - - def _cleanup_certificate_requests(self): - """Clean up certificate requests. - - Remove any certificate requests that falls into one of the following categories: - - The CSR attributes do not match any of the certificate requests defined in - the charm's certificate_requests attribute. - - The CSR public key does not match the private key. - """ - for requirer_csr in self.get_csrs_from_requirer_relation_data(): - if not self._csr_matches_certificate_request( - certificate_signing_request=requirer_csr.certificate_signing_request, - is_ca=requirer_csr.is_ca, - ): - self._remove_requirer_csr_from_relation_data( - requirer_csr.certificate_signing_request - ) - logger.info( - "Removed CSR from relation data because it did not match any certificate request" # noqa: E501 - ) - elif ( - self.private_key - and not requirer_csr.certificate_signing_request.matches_private_key( - self.private_key - ) - ): - self._remove_requirer_csr_from_relation_data( - requirer_csr.certificate_signing_request - ) - logger.info( - "Removed CSR from relation data because it did not match the private key" # noqa: E501 - ) - - def _tls_relation_created(self) -> bool: - relation = self.model.get_relation(self.relationship_name) - if not relation: - return False - return True - - def _get_private_key_secret_label(self) -> str: - if self.mode == Mode.UNIT: - return f"{LIBID}-private-key-{self._get_unit_number()}-{self.relationship_name}" - elif self.mode == Mode.APP: - return f"{LIBID}-private-key-{self.relationship_name}" - else: - raise TLSCertificatesError("Invalid mode. Must be Mode.UNIT or Mode.APP.") - - def _get_csr_secret_label(self, csr: CertificateSigningRequest) -> str: - csr_in_sha256_hex = csr.get_sha256_hex() - if self.mode == Mode.UNIT: - return f"{LIBID}-certificate-{self._get_unit_number()}-{csr_in_sha256_hex}" - elif self.mode == Mode.APP: - return f"{LIBID}-certificate-{csr_in_sha256_hex}" - else: - raise TLSCertificatesError("Invalid mode. Must be Mode.UNIT or Mode.APP.") - - def _get_unit_number(self) -> str: - return self.model.unit.name.split("/")[1] - - -class TLSCertificatesProvidesV4(Object): - """TLS certificates provider class to be instantiated by TLS certificates providers.""" - - def __init__(self, charm: CharmBase, relationship_name: str): - super().__init__(charm, relationship_name) - self.framework.observe(charm.on[relationship_name].relation_joined, self._configure) - self.framework.observe(charm.on[relationship_name].relation_changed, self._configure) - self.framework.observe(charm.on.update_status, self._configure) - self.charm = charm - self.relationship_name = relationship_name - self._security_logger = _OWASPLogger(application=f"tls-certificates-{charm.app.name}") - - def _configure(self, _: EventBase) -> None: - """Handle update status and tls relation changed events. - - This is a common hook triggered on a regular basis. - - Revoke certificates for which no csr exists - """ - if not self.model.unit.is_leader(): - return - self._remove_certificates_for_which_no_csr_exists() - - def _remove_certificates_for_which_no_csr_exists(self) -> None: - provider_certificates = self.get_provider_certificates() - requirer_csrs = [ - request.certificate_signing_request for request in self.get_certificate_requests() - ] - for provider_certificate in provider_certificates: - if provider_certificate.certificate_signing_request not in requirer_csrs: - tls_relation = self._get_tls_relations( - relation_id=provider_certificate.relation_id - ) - self._remove_provider_certificate( - certificate=provider_certificate.certificate, - relation=tls_relation[0], - ) - - def _get_tls_relations(self, relation_id: Optional[int] = None) -> List[Relation]: - return ( - [ - relation - for relation in self.model.relations[self.relationship_name] - if relation.id == relation_id - ] - if relation_id is not None - else self.model.relations.get(self.relationship_name, []) - ) - - def get_certificate_requests( - self, relation_id: Optional[int] = None - ) -> List[RequirerCertificateRequest]: - """Load certificate requests from the relation data.""" - relations = self._get_tls_relations(relation_id) - requirer_csrs: List[RequirerCertificateRequest] = [] - for relation in relations: - for unit in relation.units: - requirer_csrs.extend(self._load_requirer_databag(relation, unit)) - requirer_csrs.extend(self._load_requirer_databag(relation, relation.app)) - return requirer_csrs - - def _load_requirer_databag( - self, relation: Relation, unit_or_app: Union[Application, Unit] - ) -> List[RequirerCertificateRequest]: - try: - requirer_relation_data = _RequirerData.load(relation.data.get(unit_or_app, {})) - except DataValidationError: - logger.debug("Invalid requirer relation data for %s", unit_or_app.name) - return [] - return [ - RequirerCertificateRequest( - relation_id=relation.id, - certificate_signing_request=CertificateSigningRequest.from_string( - csr.certificate_signing_request - ), - is_ca=csr.ca if csr.ca else False, - ) - for csr in requirer_relation_data.certificate_signing_requests - ] - - def _add_provider_certificate( - self, - relation: Relation, - provider_certificate: ProviderCertificate, - ) -> None: - chain = [str(certificate) for certificate in provider_certificate.chain] - if chain[0] != str(provider_certificate.certificate): - logger.warning( - "The order of the chain from the TLS Certificates Provider is incorrect. " - "The leaf certificate should be the first element of the chain." - ) - elif not chain_has_valid_order(chain): - logger.warning( - "The order of the chain from the TLS Certificates Provider is partially incorrect." - ) - new_certificate = _Certificate( - certificate=str(provider_certificate.certificate), - certificate_signing_request=str(provider_certificate.certificate_signing_request), - ca=str(provider_certificate.ca), - chain=chain, - ) - provider_certificates = self._load_provider_certificates(relation) - if new_certificate in provider_certificates: - logger.info("Certificate already in relation data - Doing nothing") - return - provider_certificates.append(new_certificate) - self._dump_provider_certificates(relation=relation, certificates=provider_certificates) - - def _load_provider_certificates(self, relation: Relation) -> List[_Certificate]: - try: - provider_relation_data = _ProviderApplicationData.load(relation.data[self.charm.app]) - except DataValidationError: - logger.debug("Invalid provider relation data") - return [] - return copy.deepcopy(provider_relation_data.certificates) - - def _dump_provider_certificates(self, relation: Relation, certificates: List[_Certificate]): - try: - _ProviderApplicationData(certificates=certificates).dump(relation.data[self.model.app]) - logger.info("Certificate relation data updated") - except ModelError: - logger.warning("Failed to update relation data") - - def _remove_provider_certificate( - self, - relation: Relation, - certificate: Optional[Certificate] = None, - certificate_signing_request: Optional[CertificateSigningRequest] = None, - ) -> None: - """Remove certificate based on certificate or certificate signing request.""" - provider_certificates = self._load_provider_certificates(relation) - for provider_certificate in provider_certificates: - if certificate and provider_certificate.certificate == str(certificate): - provider_certificates.remove(provider_certificate) - if ( - certificate_signing_request - and provider_certificate.certificate_signing_request - == str(certificate_signing_request) - ): - provider_certificates.remove(provider_certificate) - self._dump_provider_certificates(relation=relation, certificates=provider_certificates) - - def revoke_all_certificates(self) -> None: - """Revoke all certificates of this provider. - - This method is meant to be used when the Root CA has changed. - """ - if not self.model.unit.is_leader(): - logger.warning("Unit is not a leader - will not set relation data") - return - relations = self._get_tls_relations() - for relation in relations: - provider_certificates = self._load_provider_certificates(relation) - for certificate in provider_certificates: - certificate.revoked = True - self._dump_provider_certificates(relation=relation, certificates=provider_certificates) - self._security_logger.log_event( - event="all_certificates_revoked", - level=logging.WARNING, - description="All certificates revoked", - ) - - def set_relation_certificate( - self, - provider_certificate: ProviderCertificate, - ) -> None: - """Add certificates to relation data. - - Args: - provider_certificate (ProviderCertificate): ProviderCertificate object - - Returns: - None - """ - if not self.model.unit.is_leader(): - logger.warning("Unit is not a leader - will not set relation data") - return - certificates_relation = self.model.get_relation( - relation_name=self.relationship_name, relation_id=provider_certificate.relation_id - ) - if not certificates_relation: - raise TLSCertificatesError(f"Relation {self.relationship_name} does not exist") - self._remove_provider_certificate( - relation=certificates_relation, - certificate_signing_request=provider_certificate.certificate_signing_request, - ) - self._add_provider_certificate( - relation=certificates_relation, - provider_certificate=provider_certificate, - ) - self._security_logger.log_event( - event="certificate_provided", - level=logging.INFO, - description="Certificate provided to requirer", - relation_id=str(provider_certificate.relation_id), - common_name=provider_certificate.certificate.common_name, - ) - - def get_issued_certificates( - self, relation_id: Optional[int] = None - ) -> List[ProviderCertificate]: - """Return a List of issued (non revoked) certificates. - - Returns: - List: List of ProviderCertificate objects - """ - if not self.model.unit.is_leader(): - logger.warning("Unit is not a leader - will not read relation data") - return [] - provider_certificates = self.get_provider_certificates(relation_id=relation_id) - return [certificate for certificate in provider_certificates if not certificate.revoked] - - def get_provider_certificates( - self, relation_id: Optional[int] = None - ) -> List[ProviderCertificate]: - """Return a List of issued certificates.""" - certificates: List[ProviderCertificate] = [] - relations = self._get_tls_relations(relation_id) - for relation in relations: - if not relation.app: - logger.warning("Relation %s does not have an application", relation.id) - continue - for certificate in self._load_provider_certificates(relation): - certificates.append(certificate.to_provider_certificate(relation_id=relation.id)) - return certificates - - def get_unsolicited_certificates( - self, relation_id: Optional[int] = None - ) -> List[ProviderCertificate]: - """Return provider certificates for which no certificate requests exists. - - Those certificates should be revoked. - """ - unsolicited_certificates: List[ProviderCertificate] = [] - provider_certificates = self.get_provider_certificates(relation_id=relation_id) - requirer_csrs = self.get_certificate_requests(relation_id=relation_id) - list_of_csrs = [csr.certificate_signing_request for csr in requirer_csrs] - for certificate in provider_certificates: - if certificate.certificate_signing_request not in list_of_csrs: - unsolicited_certificates.append(certificate) - return unsolicited_certificates - - def get_outstanding_certificate_requests( - self, relation_id: Optional[int] = None - ) -> List[RequirerCertificateRequest]: - """Return CSR's for which no certificate has been issued. - - Args: - relation_id (int): Relation id - - Returns: - list: List of RequirerCertificateRequest objects. - """ - requirer_csrs = self.get_certificate_requests(relation_id=relation_id) - outstanding_csrs: List[RequirerCertificateRequest] = [] - for relation_csr in requirer_csrs: - if not self._certificate_issued_for_csr( - csr=relation_csr.certificate_signing_request, - relation_id=relation_id, - ): - outstanding_csrs.append(relation_csr) - return outstanding_csrs - - def _certificate_issued_for_csr( - self, csr: CertificateSigningRequest, relation_id: Optional[int] - ) -> bool: - """Check whether a certificate has been issued for a given CSR.""" - issued_certificates_per_csr = self.get_issued_certificates(relation_id=relation_id) - for issued_certificate in issued_certificates_per_csr: - if issued_certificate.certificate_signing_request == csr: - return csr.matches_certificate(issued_certificate.certificate) - return False diff --git a/dovecot-charm/pyproject.toml b/dovecot-charm/pyproject.toml index ca14524..f427468 100644 --- a/dovecot-charm/pyproject.toml +++ b/dovecot-charm/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "pydantic[email]>=2.12.5", "cryptography>=46.0.6", "charmlibs-systemd==1.0.0", + "charmlibs-interfaces-tls-certificates>=1.8.1", ] [dependency-groups] diff --git a/dovecot-charm/src/charm.py b/dovecot-charm/src/charm.py index a6bb3fa..9588591 100644 --- a/dovecot-charm/src/charm.py +++ b/dovecot-charm/src/charm.py @@ -14,15 +14,15 @@ import ops from charmhelpers.core import host from charmlibs import apt, systemd -from ops.charm import CharmBase -from ops.main import main -from ops.model import BlockedStatus, MaintenanceStatus - from charmlibs.interfaces.tls_certificates import ( CertificateAvailableEvent, CertificateRequestAttributes, TLSCertificatesRequiresV4, ) +from ops.charm import CharmBase +from ops.main import main +from ops.model import ActiveStatus, BlockedStatus, MaintenanceStatus + from constants import ( DOVECOT_CONF_TARGET, DOVECOT_CONF_TEMPLATE, diff --git a/dovecot-charm/src/charms/tls_certificates_interface/v4/tls_certificates.py b/dovecot-charm/src/charms/tls_certificates_interface/v4/tls_certificates.py deleted file mode 100644 index b779c7c..0000000 --- a/dovecot-charm/src/charms/tls_certificates_interface/v4/tls_certificates.py +++ /dev/null @@ -1,2525 +0,0 @@ -# Copyright 2024 Canonical Ltd. -# See LICENSE file for licensing details. - -"""Legacy Charmhub-hosted lib, deprecated in favour of ``charmlibs.interfaces.tls_certificates``. - -WARNING: This library is deprecated. -It will not receive feature updates or bugfixes. -``charmlibs.interfaces.tls_certificates`` 1.0 is a bug-for-bug compatible migration of this library. - -To migrate: -1. Add 'charmlibs-interfaces-tls-certificates~=1.0' to your charm's dependencies, - and remove this Charmhub-hosted library from your charm. -2. You can also remove any dependencies added to your charm only because of this library. -3. Replace `from charms.tls_certificates_interface.v4 import tls_certificates` - with `from charmlibs.interfaces import tls_certificates`. - -Read more: -- https://documentation.ubuntu.com/charmlibs -- https://pypi.org/project/charmlibs-interfaces-tls-certificates - ---- - -Charm library for managing TLS certificates (V4). - -This library contains the Requires and Provides classes for handling the tls-certificates -interface. - -Pre-requisites: - - Juju >= 3.0 - - cryptography >= 43.0.0 - - pydantic >= 1.0 - -Learn more on how-to use the TLS Certificates interface library by reading the documentation: -- https://charmhub.io/tls-certificates-interface/ - -""" # noqa: D214, D405, D411, D416 - -import copy -import ipaddress -import json -import logging -import uuid -import warnings -from contextlib import suppress -from dataclasses import asdict, dataclass, field -from datetime import datetime, timedelta, timezone -from enum import Enum -from typing import ( - Collection, - Dict, - FrozenSet, - List, - MutableMapping, - Optional, - Set, - Tuple, - Union, -) - -import pydantic -from cryptography import x509 -from cryptography.exceptions import InvalidSignature -from cryptography.hazmat.primitives import hashes, serialization -from cryptography.hazmat.primitives.asymmetric import rsa -from cryptography.hazmat.primitives.asymmetric.types import CertificateIssuerPrivateKeyTypes -from cryptography.x509.oid import ExtensionOID, NameOID -from ops import BoundEvent, CharmBase, CharmEvents, Secret, SecretExpiredEvent, SecretRemoveEvent -from ops.framework import EventBase, EventSource, Handle, Object -from ops.jujuversion import JujuVersion -from ops.model import Application, ModelError, Relation, SecretNotFoundError, Unit - -# The unique Charmhub library identifier, never change it -LIBID = "afd8c2bccf834997afce12c2706d2ede" - -# Increment this major API version when introducing breaking changes -LIBAPI = 4 - -# Increment this PATCH version before using `charmcraft publish-lib` or reset -# to 0 if you are raising the major API version -LIBPATCH = 27 - -PYDEPS = [ - "cryptography>=43.0.0", - "pydantic", -] -IS_PYDANTIC_V1 = int(pydantic.version.VERSION.split(".")[0]) < 2 - -logger = logging.getLogger(__name__) - -NESTED_JSON_KEY = "owasp_event" - - -@dataclass -class _OWASPLogEvent: - """OWASP-compliant log event.""" - - datetime: str - event: str - level: str - description: str - type: str = "security" - labels: Dict[str, str] = field(default_factory=dict) - - def to_json(self) -> str: - return json.dumps(self.to_dict(), ensure_ascii=False) - - def to_dict(self) -> Dict: - log_event = dict(asdict(self), **self.labels) - log_event.pop("labels", None) - return {k: v for k, v in log_event.items() if v is not None} - - -class _OWASPLogger: - """OWASP-compliant logger for security events.""" - - def __init__(self, application: Optional[str] = None): - self.application = application - self._logger = logging.getLogger(__name__) - - def log_event(self, event: str, level: int, description: str, **labels: str): - if self.application and "application" not in labels: - labels["application"] = self.application - log = _OWASPLogEvent( - datetime=datetime.now(timezone.utc).astimezone().isoformat(), - event=event, - level=logging.getLevelName(level), - description=description, - labels=labels, - ) - self._logger.log(level, log.to_json(), extra={NESTED_JSON_KEY: log.to_dict()}) - - -class TLSCertificatesError(Exception): - """Base class for custom errors raised by this library.""" - - -class DataValidationError(TLSCertificatesError): - """Raised when data validation fails.""" - - -class _DatabagModel(pydantic.BaseModel): - """Base databag model. - - Supports both pydantic v1 and v2. - """ - - if IS_PYDANTIC_V1: - - class Config: - """Pydantic config.""" - - # ignore any extra fields in the databag - extra = "ignore" - """Ignore any extra fields in the databag.""" - allow_population_by_field_name = True - """Allow instantiating this class by field name (instead of forcing alias).""" - - _NEST_UNDER = None - - model_config = pydantic.ConfigDict( - # tolerate additional keys in databag - extra="ignore", - # Allow instantiating this class by field name (instead of forcing alias). - populate_by_name=True, - # Custom config key: whether to nest the whole datastructure (as json) - # under a field or spread it out at the toplevel. - _NEST_UNDER=None, - ) # type: ignore - """Pydantic config.""" - - @classmethod - def load(cls, databag: MutableMapping): - """Load this model from a Juju databag.""" - if IS_PYDANTIC_V1: - return cls._load_v1(databag) - nest_under = cls.model_config.get("_NEST_UNDER") - if nest_under: - return cls.model_validate(json.loads(databag[nest_under])) - - try: - data = { - k: json.loads(v) - for k, v in databag.items() - # Don't attempt to parse model-external values - if k in {(f.alias or n) for n, f in cls.model_fields.items()} - } - except json.JSONDecodeError as e: - msg = f"invalid databag contents: expecting json. {databag}" - logger.error(msg) - raise DataValidationError(msg) from e - - try: - return cls.model_validate_json(json.dumps(data)) - except pydantic.ValidationError as e: - msg = f"failed to validate databag: {databag}" - logger.debug(msg, exc_info=True) - raise DataValidationError(msg) from e - - @classmethod - def _load_v1(cls, databag: MutableMapping): - """Load implementation for pydantic v1.""" - if cls._NEST_UNDER: - return cls.parse_obj(json.loads(databag[cls._NEST_UNDER])) - - try: - data = { - k: json.loads(v) - for k, v in databag.items() - # Don't attempt to parse model-external values - if k in {f.alias for f in cls.__fields__.values()} - } - except json.JSONDecodeError as e: - msg = f"invalid databag contents: expecting json. {databag}" - logger.error(msg) - raise DataValidationError(msg) from e - - try: - return cls.parse_raw(json.dumps(data)) # type: ignore - except pydantic.ValidationError as e: - msg = f"failed to validate databag: {databag}" - logger.debug(msg, exc_info=True) - raise DataValidationError(msg) from e - - def dump(self, databag: Optional[MutableMapping] = None, clear: bool = True): - """Write the contents of this model to Juju databag. - - Args: - databag: The databag to write to. - clear: Whether to clear the databag before writing. - - Returns: - MutableMapping: The databag. - """ - if IS_PYDANTIC_V1: - return self._dump_v1(databag, clear) - if clear and databag: - databag.clear() - - if databag is None: - databag = {} - nest_under = self.model_config.get("_NEST_UNDER") - if nest_under: - databag[nest_under] = self.model_dump_json( - by_alias=True, - # skip keys whose values are default - exclude_defaults=True, - ) - return databag - - dct = self.model_dump(mode="json", by_alias=True, exclude_defaults=True) - databag.update({k: json.dumps(v) for k, v in dct.items()}) - return databag - - def _dump_v1(self, databag: Optional[MutableMapping] = None, clear: bool = True): - """Dump implementation for pydantic v1.""" - if clear and databag: - databag.clear() - - if databag is None: - databag = {} - - if self._NEST_UNDER: - databag[self._NEST_UNDER] = self.json(by_alias=True, exclude_defaults=True) - return databag - - dct = json.loads(self.json(by_alias=True, exclude_defaults=True)) - databag.update({k: json.dumps(v) for k, v in dct.items()}) - - return databag - - -class _Certificate(pydantic.BaseModel): - """Certificate model.""" - - ca: str - certificate_signing_request: str - certificate: str - chain: Optional[List[str]] = None - revoked: Optional[bool] = None - - def to_provider_certificate(self, relation_id: int) -> "ProviderCertificate": - """Convert to a ProviderCertificate.""" - return ProviderCertificate( - relation_id=relation_id, - certificate=Certificate.from_string(self.certificate), - certificate_signing_request=CertificateSigningRequest.from_string( - self.certificate_signing_request - ), - ca=Certificate.from_string(self.ca), - chain=[Certificate.from_string(certificate) for certificate in self.chain] - if self.chain - else [], - revoked=self.revoked, - ) - - -class _CertificateSigningRequest(pydantic.BaseModel): - """Certificate signing request model.""" - - certificate_signing_request: str - ca: Optional[bool] - - -class _ProviderApplicationData(_DatabagModel): - """Provider application data model.""" - - certificates: List[_Certificate] = [] - - -class _RequirerData(_DatabagModel): - """Requirer data model. - - The same model is used for the unit and application data. - """ - - certificate_signing_requests: List[_CertificateSigningRequest] = [] - - -class Mode(Enum): - """Enum representing the mode of the certificate request. - - UNIT (default): Request a certificate for the unit. - Each unit will manage its private key, - certificate signing request and certificate. - APP: Request a certificate for the application. - Only the leader unit will manage the private key, certificate signing request - and certificate. - """ - - UNIT = 1 - APP = 2 - - -class PrivateKey: - """This class represents a private key.""" - - def __init__( - self, raw: Optional[str] = None, x509_object: Optional[rsa.RSAPrivateKey] = None - ) -> None: - """Initialize the PrivateKey object. - - If both raw and x509_object are provided, x509_object takes precedence. - """ - if x509_object: - self._private_key = x509_object - elif raw: - self._private_key = serialization.load_pem_private_key( - raw.encode(), - password=None, - ) - else: - raise ValueError("Either raw private key string or x509_object must be provided") - - @property - def raw(self) -> str: - """Return the PEM-formatted string representation of the private key.""" - return str(self) - - def __str__(self): - """Return the private key as a string in PEM format.""" - return ( - self._private_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption(), - ) - .decode() - .strip() - ) - - def __hash__(self): - """Return the hash of the private key.""" - return hash(self.raw) - - @classmethod - def from_string(cls, private_key: str) -> "PrivateKey": - """Create a PrivateKey object from a private key.""" - return cls(raw=private_key) - - def is_valid(self) -> bool: - """Validate that the private key is PEM-formatted, RSA, and at least 2048 bits.""" - try: - if not isinstance(self._private_key, rsa.RSAPrivateKey): - logger.warning("Private key is not an RSA key") - return False - - if self._private_key.key_size < 2048: - logger.warning("RSA key size is less than 2048 bits") - return False - - return True - except ValueError: - logger.warning("Invalid private key format") - return False - - @classmethod - def generate(cls, key_size: int = 2048, public_exponent: int = 65537) -> "PrivateKey": - """Generate a new RSA private key. - - Args: - key_size: The size of the key in bits. - public_exponent: The public exponent of the key. - - Returns: - PrivateKey: The generated private key. - """ - private_key = rsa.generate_private_key( - public_exponent=public_exponent, - key_size=key_size, - ) - _OWASPLogger().log_event( - event="private_key_generated", - level=logging.INFO, - description="Private key generated", - key_size=str(key_size), - ) - return PrivateKey(x509_object=private_key) - - def __eq__(self, other: object) -> bool: - """Check if two PrivateKey objects are equal.""" - if not isinstance(other, PrivateKey): - return NotImplemented - return self.raw == other.raw - - -class Certificate: - """This class represents a certificate.""" - - _cert: x509.Certificate - - def __init__( - self, - raw: Optional[str] = None, # Must remain first argument for backwards compatibility - # Old Interface fields (ignored) - common_name: Optional[str] = None, - expiry_time: Optional[datetime] = None, - validity_start_time: Optional[datetime] = None, - is_ca: Optional[bool] = None, - sans_dns: Optional[Set[str]] = None, - sans_ip: Optional[Set[str]] = None, - sans_oid: Optional[Set[str]] = None, - email_address: Optional[str] = None, - organization: Optional[str] = None, - organizational_unit: Optional[str] = None, - country_name: Optional[str] = None, - state_or_province_name: Optional[str] = None, - locality_name: Optional[str] = None, - # End Old Interface fields - x509_object: Optional[x509.Certificate] = None, - ) -> None: - """Initialize the Certificate object. - - This initializer must maintain the old interface while also allowing - instantiation from an existing x509_object. It ignores all fields - other than raw and x509_object, preferring x509_object. - """ - if x509_object: - self._cert = x509_object - elif raw: - self._cert = x509.load_pem_x509_certificate(data=raw.encode()) - else: - raise ValueError("Either raw certificate string or x509_object must be provided") - - @property - def raw(self) -> str: - """Return the PEM-formatted string representation of the certificate.""" - return str(self) - - @property - def common_name(self) -> str: - """Return the common name of the certificate.""" - # We maintain compatibility with the old interface by returning - # an empty string if no common name is set. - common_name = self._cert.subject.get_attributes_for_oid(NameOID.COMMON_NAME) - return str(common_name[0].value) if common_name else "" - - @property - def expiry_time(self) -> datetime: - """Return the expiry time of the certificate.""" - return self._cert.not_valid_after_utc - - @property - def validity_start_time(self) -> datetime: - """Return the validity start time of the certificate.""" - return self._cert.not_valid_before_utc - - @property - def is_ca(self) -> bool: - """Return whether the certificate is a CA certificate.""" - try: - return self._cert.extensions.get_extension_for_oid( - ExtensionOID.BASIC_CONSTRAINTS - ).value.ca # type: ignore[reportAttributeAccessIssue] - except x509.ExtensionNotFound: - return False - - @property - def sans_dns(self) -> Optional[Set[str]]: - """Return the DNS Subject Alternative Names of the certificate.""" - with suppress(x509.ExtensionNotFound): - sans = self._cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value - return {str(san) for san in sans.get_values_for_type(x509.DNSName)} - return None - - @property - def sans_ip(self) -> Optional[Set[str]]: - """Return the IP Subject Alternative Names of the certificate.""" - with suppress(x509.ExtensionNotFound): - sans = self._cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value - return {str(san) for san in sans.get_values_for_type(x509.IPAddress)} - return None - - @property - def sans_oid(self) -> Optional[Set[str]]: - """Return the OID Subject Alternative Names of the certificate.""" - with suppress(x509.ExtensionNotFound): - sans = self._cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value - return {str(san.dotted_string) for san in sans.get_values_for_type(x509.RegisteredID)} - return None - - @property - def email_address(self) -> Optional[str]: - """Return the email address of the certificate.""" - email_address = self._cert.subject.get_attributes_for_oid(NameOID.EMAIL_ADDRESS) - return str(email_address[0].value) if email_address else None - - @property - def organization(self) -> Optional[str]: - """Return the organization name of the certificate.""" - organization = self._cert.subject.get_attributes_for_oid(NameOID.ORGANIZATION_NAME) - return str(organization[0].value) if organization else None - - @property - def organizational_unit(self) -> Optional[str]: - """Return the organizational unit name of the certificate.""" - organizational_unit = self._cert.subject.get_attributes_for_oid( - NameOID.ORGANIZATIONAL_UNIT_NAME - ) - return str(organizational_unit[0].value) if organizational_unit else None - - @property - def country_name(self) -> Optional[str]: - """Return the country name of the certificate.""" - country_name = self._cert.subject.get_attributes_for_oid(NameOID.COUNTRY_NAME) - return str(country_name[0].value) if country_name else None - - @property - def state_or_province_name(self) -> Optional[str]: - """Return the state or province name of the certificate.""" - state_or_province_name = self._cert.subject.get_attributes_for_oid( - NameOID.STATE_OR_PROVINCE_NAME - ) - return str(state_or_province_name[0].value) if state_or_province_name else None - - @property - def locality_name(self) -> Optional[str]: - """Return the locality name of the certificate.""" - locality_name = self._cert.subject.get_attributes_for_oid(NameOID.LOCALITY_NAME) - return str(locality_name[0].value) if locality_name else None - - def __str__(self) -> str: - """Return the certificate as a string.""" - return self._cert.public_bytes(serialization.Encoding.PEM).decode().strip() - - def __eq__(self, other: object) -> bool: - """Check if two Certificate objects are equal.""" - if not isinstance(other, Certificate): - return NotImplemented - return self.raw == other.raw - - @classmethod - def from_string(cls, certificate: str) -> "Certificate": - """Create a Certificate object from a certificate.""" - try: - certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) - except ValueError as e: - logger.error("Could not load certificate: %s", e) - raise TLSCertificatesError("Could not load certificate") - - return cls(x509_object=certificate_object) - - def matches_private_key(self, private_key: PrivateKey) -> bool: - """Check if this certificate matches a given private key. - - Args: - private_key (PrivateKey): The private key to validate against. - - Returns: - bool: True if the certificate matches the private key, False otherwise. - """ - try: - cert_public_key = self._cert.public_key() - key_public_key = private_key._private_key.public_key() - - if not isinstance(cert_public_key, rsa.RSAPublicKey): - logger.warning("Certificate does not use RSA public key") - return False - - if not isinstance(key_public_key, rsa.RSAPublicKey): - logger.warning("Private key is not an RSA key") - return False - - return cert_public_key.public_numbers() == key_public_key.public_numbers() - except Exception as e: - logger.warning("Failed to validate certificate and private key match: %s", e) - return False - - @classmethod - def generate( - cls, - csr: "CertificateSigningRequest", - ca: "Certificate", - ca_private_key: "PrivateKey", - validity: timedelta, - is_ca: bool = False, - ) -> "Certificate": - """Generate a certificate from a CSR signed by the given CA and CA private key. - - Args: - csr: The certificate signing request. - ca: The CA certificate. - ca_private_key: The CA private key. - validity: The validity period of the certificate. - is_ca: Whether the generated certificate is a CA certificate. - - Returns: - Certificate: The generated certificate. - """ - # Ideally, this would be the constructor, but we can't add new - # required parameters to the constructor without breaking backwards - # compatibility. - private_key = serialization.load_pem_private_key( - str(ca_private_key).encode(), password=None - ) - assert isinstance(private_key, CertificateIssuerPrivateKeyTypes) - - # Create a certificate builder - cert_builder = x509.CertificateBuilder( - subject_name=csr._csr.subject, - # issuer_name=ca._cert.subject, # TODO: Validate this is correct, the old code used `issuer` - issuer_name=ca._cert.issuer, - public_key=csr._csr.public_key(), - serial_number=x509.random_serial_number(), - not_valid_before=datetime.now(timezone.utc), - not_valid_after=datetime.now(timezone.utc) + validity, - ) - extensions = _generate_certificate_request_extensions( - authority_key_identifier=ca._cert.extensions.get_extension_for_class( - x509.SubjectKeyIdentifier - ).value.key_identifier, - csr=csr._csr, - is_ca=is_ca, - ) - for extension in extensions: - try: - cert_builder = cert_builder.add_extension(extension.value, extension.critical) - except ValueError as e: - logger.error("Could not add extension to certificate: %s", e) - raise TLSCertificatesError("Could not add extension to certificate") from e - - # Sign the certificate with the CA's private key - cert = cert_builder.sign(private_key=private_key, algorithm=hashes.SHA256()) - _OWASPLogger().log_event( - event="certificate_generated", - level=logging.INFO, - description="Certificate generated from CSR", - common_name=csr.common_name, - is_ca=str(is_ca), - validity_days=str(validity.days), - ) - - return cls(x509_object=cert) - - @classmethod - def generate_self_signed_ca( - cls, - attributes: "CertificateRequestAttributes", - private_key: PrivateKey, - validity: timedelta, - ) -> "Certificate": - """Generate a self-signed CA certificate. - - Args: - attributes: The certificate request attributes. - private_key: The private key to sign the CA certificate. - validity: The validity period of the CA certificate. - - Returns: - Certificate: The generated CA certificate. - """ - assert isinstance(private_key._private_key, rsa.RSAPrivateKey) - - public_key = private_key._private_key.public_key() - - builder = x509.CertificateBuilder( - public_key=public_key, - serial_number=x509.random_serial_number(), - not_valid_before=datetime.now(timezone.utc), - not_valid_after=datetime.now(timezone.utc) + validity, - ) - - if subject_name := _extract_subject_name_attributes(attributes): - builder = builder.subject_name(subject_name).issuer_name(subject_name) - - builder = ( - builder.add_extension( - x509.SubjectKeyIdentifier.from_public_key(public_key), critical=False - ) - .add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True) - .add_extension( - x509.KeyUsage( - digital_signature=True, - key_encipherment=True, - key_cert_sign=True, - key_agreement=False, - content_commitment=False, - data_encipherment=False, - crl_sign=False, - encipher_only=False, - decipher_only=False, - ), - critical=True, - ) - ) - - if san_extension := _san_extension( - email_address=attributes.email_address, - sans_dns=attributes.sans_dns, - sans_ip=attributes.sans_ip, - sans_oid=attributes.sans_oid, - ): - builder = builder.add_extension(san_extension, critical=False) - - cert = cls(x509_object=builder.sign(private_key._private_key, algorithm=hashes.SHA256())) - - _OWASPLogger().log_event( - event="ca_certificate_generated", - level=logging.INFO, - description="CA certificate generated", - common_name=cert.common_name, - validity_days=str(validity.days), - ) - - return cert - - def __hash__(self): - """Return the hash of the private key.""" - return hash(self.raw) - - -class CertificateSigningRequest: - """A representation of the certificate signing request.""" - - _csr: x509.CertificateSigningRequest - - def __init__( - self, - raw: Optional[str] = None, # Must remain first argument for backwards compatibility - # Old Interface fields (ignored) - common_name: Optional[str] = None, - sans_dns: Optional[Set[str]] = None, - sans_ip: Optional[Set[str]] = None, - sans_oid: Optional[Set[str]] = None, - email_address: Optional[str] = None, - organization: Optional[str] = None, - organizational_unit: Optional[str] = None, - country_name: Optional[str] = None, - state_or_province_name: Optional[str] = None, - locality_name: Optional[str] = None, - has_unique_identifier: Optional[bool] = None, - # End Old Interface fields - x509_object: Optional[x509.CertificateSigningRequest] = None, - ): - """Initialize the CertificateSigningRequest object. - - This initializer must maintain the old interface while also allowing - instantiation from an existing x509_object. It ignores all fields - other than raw and x509_object, preferring x509_object. - """ - if x509_object: - self._csr = x509_object - return - elif raw: - try: - self._csr = x509.load_pem_x509_csr(raw.encode()) - except ValueError as e: - logger.error("Could not load CSR: %s", e) - raise TLSCertificatesError("Could not load CSR") - return - raise ValueError("Either raw CSR string or x509_object must be provided") - - @property - def common_name(self) -> str: - """Return the common name of the CSR.""" - common_name = self._csr.subject.get_attributes_for_oid(NameOID.COMMON_NAME) - return str(common_name[0].value) if common_name else "" - - @property - def sans_dns(self) -> Set[str]: - """Return the DNS Subject Alternative Names of the CSR.""" - with suppress(x509.ExtensionNotFound): - sans = self._csr.extensions.get_extension_for_class(x509.SubjectAlternativeName).value - return {str(san) for san in sans.get_values_for_type(x509.DNSName)} - return set() - - @property - def sans_ip(self) -> Set[str]: - """Return the IP Subject Alternative Names of the CSR.""" - with suppress(x509.ExtensionNotFound): - sans = self._csr.extensions.get_extension_for_class(x509.SubjectAlternativeName).value - return {str(san) for san in sans.get_values_for_type(x509.IPAddress)} - return set() - - @property - def sans_oid(self) -> Set[str]: - """Return the OID Subject Alternative Names of the CSR.""" - with suppress(x509.ExtensionNotFound): - sans = self._csr.extensions.get_extension_for_class(x509.SubjectAlternativeName).value - return {str(san.dotted_string) for san in sans.get_values_for_type(x509.RegisteredID)} - return set() - - @property - def email_address(self) -> Optional[str]: - """Return the email address of the CSR.""" - email_address = self._csr.subject.get_attributes_for_oid(NameOID.EMAIL_ADDRESS) - return str(email_address[0].value) if email_address else None - - @property - def organization(self) -> Optional[str]: - """Return the organization name of the CSR.""" - organization = self._csr.subject.get_attributes_for_oid(NameOID.ORGANIZATION_NAME) - return str(organization[0].value) if organization else None - - @property - def organizational_unit(self) -> Optional[str]: - """Return the organizational unit name of the CSR.""" - organizational_unit = self._csr.subject.get_attributes_for_oid( - NameOID.ORGANIZATIONAL_UNIT_NAME - ) - return str(organizational_unit[0].value) if organizational_unit else None - - @property - def country_name(self) -> Optional[str]: - """Return the country name of the CSR.""" - country_name = self._csr.subject.get_attributes_for_oid(NameOID.COUNTRY_NAME) - return str(country_name[0].value) if country_name else None - - @property - def state_or_province_name(self) -> Optional[str]: - """Return the state or province name of the CSR.""" - state_or_province_name = self._csr.subject.get_attributes_for_oid( - NameOID.STATE_OR_PROVINCE_NAME - ) - return str(state_or_province_name[0].value) if state_or_province_name else None - - @property - def locality_name(self) -> Optional[str]: - """Return the locality name of the CSR.""" - locality_name = self._csr.subject.get_attributes_for_oid(NameOID.LOCALITY_NAME) - return str(locality_name[0].value) if locality_name else None - - @property - def has_unique_identifier(self) -> bool: - """Return whether the CSR has a unique identifier.""" - unique_identifier = self._csr.subject.get_attributes_for_oid( - NameOID.X500_UNIQUE_IDENTIFIER - ) - return bool(unique_identifier) - - @property - def raw(self) -> str: - """Return the PEM-formatted string representation of the CSR.""" - return self.__str__() - - def __str__(self) -> str: - """Return the CSR as a string.""" - return self._csr.public_bytes(serialization.Encoding.PEM).decode().strip() - - @property - def additional_critical_extensions(self) -> List[x509.ExtensionType]: - """Return additional critical extensions present on the CSR (excluding SAN).""" - extensions: List[x509.ExtensionType] = [] - for extension in self._csr.extensions: - if extension.critical and extension.oid != ExtensionOID.SUBJECT_ALTERNATIVE_NAME: - extensions.append(extension.value) - return extensions - - @classmethod - def from_string(cls, csr: str) -> "CertificateSigningRequest": - """Create a CertificateSigningRequest object from a CSR.""" - return cls(raw=csr) - - @classmethod - def from_csr(cls, csr: x509.CertificateSigningRequest) -> "CertificateSigningRequest": - """Create a CertificateSigningRequest object from a CSR.""" - return cls(x509_object=csr) - - def __eq__(self, other: object) -> bool: - """Check if two CertificateSigningRequest objects are equal.""" - if not isinstance(other, CertificateSigningRequest): - return NotImplemented - return self.raw == other.raw - - def __hash__(self): - """Return the hash of the private key.""" - return hash(self.raw) - - def matches_certificate(self, certificate: Certificate) -> bool: - """Check if this CSR matches a given certificate. - - Args: - certificate (Certificate): The certificate to validate against. - - Returns: - bool: True if the CSR matches the certificate, False otherwise. - """ - return self._csr.public_key() == certificate._cert.public_key() - - def matches_private_key(self, key: PrivateKey) -> bool: - """Check if a CSR matches a private key. - - This function only works with RSA keys. - - Args: - key (PrivateKey): Private key - Returns: - bool: True/False depending on whether the CSR matches the private key. - """ - try: - key_object_public_key = key._private_key.public_key() - csr_object_public_key = self._csr.public_key() - if not isinstance(key_object_public_key, rsa.RSAPublicKey): - logger.warning("Key is not an RSA key") - return False - if not isinstance(csr_object_public_key, rsa.RSAPublicKey): - logger.warning("CSR is not an RSA key") - return False - if ( - csr_object_public_key.public_numbers().n - != key_object_public_key.public_numbers().n - ): - logger.warning("Public key numbers between CSR and key do not match") - return False - except ValueError: - logger.warning("Could not load certificate or CSR.") - return False - return True - - def get_sha256_hex(self) -> str: - """Calculate the hash of the provided data and return the hexadecimal representation.""" - digest = hashes.Hash(hashes.SHA256()) - digest.update(self.raw.encode()) - return digest.finalize().hex() - - def sign( - self, ca: Certificate, ca_private_key: PrivateKey, validity: timedelta, is_ca: bool = False - ) -> Certificate: - """Sign this CSR with the given CA and CA private key. - - Args: - ca: The CA certificate. - ca_private_key: The CA private key. - validity: The validity period of the certificate. - is_ca: Whether the generated certificate is a CA certificate. - - Returns: - Certificate: The signed certificate. - """ - return Certificate.generate( - csr=self, - ca=ca, - ca_private_key=ca_private_key, - validity=validity, - is_ca=is_ca, - ) - - @classmethod - def generate( - cls, - attributes: "CertificateRequestAttributes", - private_key: PrivateKey, - ) -> "CertificateSigningRequest": - """Generate a CSR using the supplied attributes and private key. - - Args: - attributes (CertificateRequestAttributes): Certificate request attributes - private_key (PrivateKey): Private key - Returns: - CertificateSigningRequest: CSR - """ - signing_key = private_key._private_key - assert isinstance(signing_key, CertificateIssuerPrivateKeyTypes) - - csr_builder = x509.CertificateSigningRequestBuilder() - if subject_name := _extract_subject_name_attributes(attributes): - csr_builder = csr_builder.subject_name(subject_name) - - _sans: List[x509.GeneralName] = [] - if attributes.sans_oid: - _sans.extend( - [x509.RegisteredID(x509.ObjectIdentifier(san)) for san in attributes.sans_oid] - ) - if attributes.sans_ip: - _sans.extend([x509.IPAddress(ipaddress.ip_address(san)) for san in attributes.sans_ip]) - if attributes.sans_dns: - _sans.extend([x509.DNSName(san) for san in attributes.sans_dns]) - if _sans: - csr_builder = csr_builder.add_extension( - x509.SubjectAlternativeName(set(_sans)), critical=False - ) - if attributes.additional_critical_extensions: - for extension in attributes.additional_critical_extensions: - csr_builder = csr_builder.add_extension(extension, critical=True) - signed_certificate_request = csr_builder.sign(signing_key, hashes.SHA256()) - return cls(x509_object=signed_certificate_request) - - -class CertificateRequestAttributes: - """A representation of the certificate request attributes.""" - - def __init__( - self, - common_name: Optional[str] = None, - sans_dns: Optional[Collection[str]] = None, - sans_ip: Optional[Collection[str]] = None, - sans_oid: Optional[Collection[str]] = None, - email_address: Optional[str] = None, - organization: Optional[str] = None, - organizational_unit: Optional[str] = None, - country_name: Optional[str] = None, - state_or_province_name: Optional[str] = None, - locality_name: Optional[str] = None, - is_ca: bool = False, - add_unique_id_to_subject_name: bool = True, - additional_critical_extensions: Optional[Collection[x509.ExtensionType]] = None, - ): - if not common_name and not sans_dns and not sans_ip and not sans_oid: - raise ValueError( - "At least one of common_name, sans_dns, sans_ip, or sans_oid must be provided" - ) - self._common_name = common_name - self._sans_dns = set(sans_dns) if sans_dns else None - self._sans_ip = set(sans_ip) if sans_ip else None - self._sans_oid = set(sans_oid) if sans_oid else None - self._email_address = email_address - self._organization = organization - self._organizational_unit = organizational_unit - self._country_name = country_name - self._state_or_province_name = state_or_province_name - self._locality_name = locality_name - self._is_ca = is_ca - self._add_unique_id_to_subject_name = add_unique_id_to_subject_name - self._additional_critical_extensions = list(additional_critical_extensions or []) - - @property - def common_name(self) -> str: - """Return the common name.""" - # For legacy interface compatibility, return empty string if not set - return self._common_name if self._common_name else "" - - @property - def sans_dns(self) -> Optional[Set[str]]: - """Return the DNS Subject Alternative Names.""" - return self._sans_dns - - @property - def sans_ip(self) -> Optional[Set[str]]: - """Return the IP Subject Alternative Names.""" - return self._sans_ip - - @property - def sans_oid(self) -> Optional[Set[str]]: - """Return the OID Subject Alternative Names.""" - return self._sans_oid - - @property - def email_address(self) -> Optional[str]: - """Return the email address.""" - return self._email_address - - @property - def organization(self) -> Optional[str]: - """Return the organization name.""" - return self._organization - - @property - def organizational_unit(self) -> Optional[str]: - """Return the organizational unit name.""" - return self._organizational_unit - - @property - def country_name(self) -> Optional[str]: - """Return the country name.""" - return self._country_name - - @property - def state_or_province_name(self) -> Optional[str]: - """Return the state or province name.""" - return self._state_or_province_name - - @property - def locality_name(self) -> Optional[str]: - """Return the locality name.""" - return self._locality_name - - @property - def is_ca(self) -> bool: - """Return whether the certificate is a CA certificate.""" - return self._is_ca - - @property - def add_unique_id_to_subject_name(self) -> bool: - """Return whether to add a unique identifier to the subject name.""" - return self._add_unique_id_to_subject_name - - @property - def additional_critical_extensions(self) -> List[x509.ExtensionType]: - """Return additional critical extensions to be added to the CSR.""" - return self._additional_critical_extensions - - @classmethod - def from_csr( - cls, csr: CertificateSigningRequest, is_ca: bool - ) -> "CertificateRequestAttributes": - """Create CertificateRequestAttributes from a CertificateSigningRequest. - - Args: - csr: The CSR to extract attributes from. - is_ca: Whether a CA certificate is being requested. - - Returns: - CertificateRequestAttributes: The extracted attributes. - """ - return cls( - common_name=csr.common_name, - sans_dns=csr.sans_dns, - sans_ip=csr.sans_ip, - sans_oid=csr.sans_oid, - email_address=csr.email_address, - organization=csr.organization, - organizational_unit=csr.organizational_unit, - country_name=csr.country_name, - state_or_province_name=csr.state_or_province_name, - locality_name=csr.locality_name, - is_ca=is_ca, - add_unique_id_to_subject_name=csr.has_unique_identifier, - additional_critical_extensions=csr.additional_critical_extensions, - ) - - def __eq__(self, other: object) -> bool: - """Check if two CertificateRequestAttributes objects are equal.""" - if not isinstance(other, CertificateRequestAttributes): - return NotImplemented - return ( - self.common_name == other.common_name - and self.sans_dns == other.sans_dns - and self.sans_ip == other.sans_ip - and self.sans_oid == other.sans_oid - and self.email_address == other.email_address - and self.organization == other.organization - and self.organizational_unit == other.organizational_unit - and self.country_name == other.country_name - and self.state_or_province_name == other.state_or_province_name - and self.locality_name == other.locality_name - and self.is_ca == other.is_ca - and self.add_unique_id_to_subject_name == other.add_unique_id_to_subject_name - and self.additional_critical_extensions == other.additional_critical_extensions - ) - - def is_valid(self) -> bool: - """Validate the attributes of the certificate request. - - Returns: - bool: True if the attributes are valid, False otherwise. - """ - if not self.common_name and not self.sans_dns and not self.sans_ip and not self.sans_oid: - logger.warning( - "At least one of common_name, sans_dns, sans_ip, or sans_oid must be provided" - ) - return False - return True - - def generate_csr( - self, - private_key: PrivateKey, - ) -> CertificateSigningRequest: - """Generate a CSR using the current attributes and a private key. - - Args: - private_key (PrivateKey): Private key to sign the CSR. - - Returns: - CertificateSigningRequest: The generated CSR. - """ - return CertificateSigningRequest.generate(self, private_key) - - -@dataclass(frozen=True) -class ProviderCertificate: - """This class represents a certificate provided by the TLS provider.""" - - relation_id: int - certificate: Certificate - certificate_signing_request: CertificateSigningRequest - ca: Certificate - chain: List[Certificate] - revoked: Optional[bool] = None - - def to_json(self) -> str: - """Return the object as a JSON string. - - Returns: - str: JSON representation of the object - """ - return json.dumps( - { - "csr": str(self.certificate_signing_request), - "certificate": str(self.certificate), - "ca": str(self.ca), - "chain": [str(cert) for cert in self.chain], - "revoked": self.revoked, - } - ) - - -@dataclass(frozen=True) -class RequirerCertificateRequest: - """This class represents a certificate signing request requested by a specific TLS requirer.""" - - relation_id: int - certificate_signing_request: CertificateSigningRequest - is_ca: bool - - -class CertificateAvailableEvent(EventBase): - """Charm Event triggered when a TLS certificate is available.""" - - def __init__( - self, - handle: Handle, - certificate: Certificate, - certificate_signing_request: CertificateSigningRequest, - ca: Certificate, - chain: List[Certificate], - ): - super().__init__(handle) - self.certificate = certificate - self.certificate_signing_request = certificate_signing_request - self.ca = ca - self.chain = chain - - def snapshot(self) -> dict: - """Return snapshot.""" - return { - "certificate": str(self.certificate), - "certificate_signing_request": str(self.certificate_signing_request), - "ca": str(self.ca), - "chain": json.dumps([str(certificate) for certificate in self.chain]), - } - - def restore(self, snapshot: dict): - """Restore snapshot.""" - self.certificate = Certificate.from_string(snapshot["certificate"]) - self.certificate_signing_request = CertificateSigningRequest.from_string( - snapshot["certificate_signing_request"] - ) - self.ca = Certificate.from_string(snapshot["ca"]) - chain_strs = json.loads(snapshot["chain"]) - self.chain = [Certificate.from_string(chain_str) for chain_str in chain_strs] - - def chain_as_pem(self) -> str: - """Return full certificate chain as a PEM string.""" - return "\n\n".join([str(cert) for cert in self.chain]) - - -def generate_private_key( - key_size: int = 2048, - public_exponent: int = 65537, -) -> PrivateKey: - """Generate a private key with the RSA algorithm. - - Args: - key_size (int): Key size in bits, must be at least 2048 bits - public_exponent: Public exponent. - - Returns: - PrivateKey: Private Key - """ - warnings.warn( - "generate_private_key() is deprecated. Use PrivateKey.generate() instead.", - DeprecationWarning, - ) - return PrivateKey.generate(key_size=key_size, public_exponent=public_exponent) - - -def calculate_relative_datetime(target_time: datetime, fraction: float) -> datetime: - """Calculate a datetime that is a given percentage from now to a target time. - - Args: - target_time (datetime): The future datetime to interpolate towards. - fraction (float): Fraction of the interval from now to target_time (0.0-1.0). - 1.0 means return target_time, - 0.9 means return the time after 90% of the interval has passed, - and 0.0 means return now. - """ - if fraction <= 0.0 or fraction > 1.0: - raise ValueError("Invalid fraction. Must be between 0.0 and 1.0") - now = datetime.now(timezone.utc) - time_until_target = target_time - now - return now + time_until_target * fraction - - -def chain_has_valid_order(chain: List[str]) -> bool: - """Check if the chain has a valid order. - - Validates that each certificate in the chain is properly signed by the next certificate. - The chain should be ordered from leaf to root, where each certificate is signed by - the next one in the chain. - - Args: - chain (List[str]): List of certificates in PEM format, ordered from leaf to root - - Returns: - bool: True if the chain has a valid order, False otherwise. - """ - if len(chain) < 2: - return True - - try: - for i in range(len(chain) - 1): - cert = x509.load_pem_x509_certificate(chain[i].encode()) - issuer = x509.load_pem_x509_certificate(chain[i + 1].encode()) - cert.verify_directly_issued_by(issuer) - return True - except (ValueError, TypeError, InvalidSignature): - return False - - -def generate_csr( # noqa: C901 - private_key: PrivateKey, - common_name: str, - sans_dns: Optional[FrozenSet[str]] = frozenset(), - sans_ip: Optional[FrozenSet[str]] = frozenset(), - sans_oid: Optional[FrozenSet[str]] = frozenset(), - organization: Optional[str] = None, - organizational_unit: Optional[str] = None, - email_address: Optional[str] = None, - country_name: Optional[str] = None, - locality_name: Optional[str] = None, - state_or_province_name: Optional[str] = None, - add_unique_id_to_subject_name: bool = True, -) -> CertificateSigningRequest: - """Generate a CSR using private key and subject. - - Args: - private_key (PrivateKey): Private key - common_name (str): Common name - sans_dns (FrozenSet[str]): DNS Subject Alternative Names - sans_ip (FrozenSet[str]): IP Subject Alternative Names - sans_oid (FrozenSet[str]): OID Subject Alternative Names - organization (Optional[str]): Organization name - organizational_unit (Optional[str]): Organizational unit name - email_address (Optional[str]): Email address - country_name (Optional[str]): Country name - state_or_province_name (Optional[str]): State or province name - locality_name (Optional[str]): Locality name - add_unique_id_to_subject_name (bool): Whether a unique ID must be added to the CSR's - subject name. Always leave to "True" when the CSR is used to request certificates - using the tls-certificates relation. - - Returns: - CertificateSigningRequest: CSR - """ - warnings.warn( - "generate_csr() is deprecated. Use CertificateRequestAttributes.generate_csr() or CertificateSigningRequest.generate() instead.", - DeprecationWarning, - ) - return CertificateRequestAttributes( - common_name=common_name, - sans_dns=sans_dns, - sans_ip=sans_ip, - sans_oid=sans_oid, - organization=organization, - organizational_unit=organizational_unit, - email_address=email_address, - country_name=country_name, - state_or_province_name=state_or_province_name, - locality_name=locality_name, - add_unique_id_to_subject_name=add_unique_id_to_subject_name, - ).generate_csr(private_key=private_key) - - -def generate_ca( - private_key: PrivateKey, - validity: timedelta, - common_name: str, - sans_dns: Optional[FrozenSet[str]] = frozenset(), - sans_ip: Optional[FrozenSet[str]] = frozenset(), - sans_oid: Optional[FrozenSet[str]] = frozenset(), - organization: Optional[str] = None, - organizational_unit: Optional[str] = None, - email_address: Optional[str] = None, - country_name: Optional[str] = None, - state_or_province_name: Optional[str] = None, - locality_name: Optional[str] = None, -) -> Certificate: - """Generate a self signed CA Certificate. - - Args: - private_key: Private key - validity: Certificate validity time - common_name: Common Name that can be an IP or a Full Qualified Domain Name (FQDN). - sans_dns: DNS Subject Alternative Names - sans_ip: IP Subject Alternative Names - sans_oid: OID Subject Alternative Names - organization: Organization name - organizational_unit: Organizational unit name - email_address: Email address - country_name: Certificate Issuing country - state_or_province_name: Certificate Issuing state or province - locality_name: Certificate Issuing locality - - Returns: - CA Certificate. - """ - warnings.warn( - "generate_ca() is deprecated. Use Certificate.generate_self_signed_ca() instead.", - DeprecationWarning, - ) - attributes = CertificateRequestAttributes( - common_name=common_name, - sans_dns=sans_dns, - sans_ip=sans_ip, - sans_oid=sans_oid, - organization=organization, - organizational_unit=organizational_unit, - email_address=email_address, - country_name=country_name, - state_or_province_name=state_or_province_name, - locality_name=locality_name, - is_ca=True, - ) - return Certificate.generate_self_signed_ca(attributes, private_key, validity) - - -def _san_extension( - email_address: Optional[str] = None, - sans_dns: Optional[Collection[str]] = frozenset(), - sans_ip: Optional[Collection[str]] = frozenset(), - sans_oid: Optional[Collection[str]] = frozenset(), -) -> Optional[x509.SubjectAlternativeName]: - sans: List[x509.GeneralName] = [] - if email_address: - # If an e-mail address was provided, it should always be in the SAN - sans.append(x509.RFC822Name(email_address)) - if sans_dns: - sans.extend([x509.DNSName(san) for san in sans_dns]) - if sans_ip: - sans.extend([x509.IPAddress(ipaddress.ip_address(san)) for san in sans_ip]) - if sans_oid: - sans.extend([x509.RegisteredID(x509.ObjectIdentifier(san)) for san in sans_oid]) - if not sans: - return None - return x509.SubjectAlternativeName(sans) - - -def generate_certificate( - csr: CertificateSigningRequest, - ca: Certificate, - ca_private_key: PrivateKey, - validity: timedelta, - is_ca: bool = False, -) -> Certificate: - """Generate a TLS certificate based on a CSR. - - Args: - csr (CertificateSigningRequest): CSR - ca (Certificate): CA Certificate - ca_private_key (PrivateKey): CA private key - validity (timedelta): Certificate validity time - is_ca (bool): Whether the certificate is a CA certificate - - Returns: - Certificate: Certificate - """ - warnings.warn( - "generate_certificate() is deprecated. Use Certificate.generate() instead.", - DeprecationWarning, - ) - return Certificate.generate( - csr=csr, - ca=ca, - ca_private_key=ca_private_key, - validity=validity, - is_ca=is_ca, - ) - - -def _extract_subject_name_attributes( - attributes: CertificateRequestAttributes, -) -> Optional[x509.Name]: - subject_name_attributes = [] - if attributes.common_name: - subject_name_attributes.append( - x509.NameAttribute(x509.NameOID.COMMON_NAME, attributes.common_name) - ) - if attributes.add_unique_id_to_subject_name: - unique_identifier = uuid.uuid4() - subject_name_attributes.append( - x509.NameAttribute(x509.NameOID.X500_UNIQUE_IDENTIFIER, str(unique_identifier)) - ) - if attributes.organization: - subject_name_attributes.append( - x509.NameAttribute(x509.NameOID.ORGANIZATION_NAME, attributes.organization) - ) - if attributes.organizational_unit: - subject_name_attributes.append( - x509.NameAttribute( - x509.NameOID.ORGANIZATIONAL_UNIT_NAME, - attributes.organizational_unit, - ) - ) - if attributes.email_address: - subject_name_attributes.append( - x509.NameAttribute(x509.NameOID.EMAIL_ADDRESS, attributes.email_address) - ) - if attributes.country_name: - subject_name_attributes.append( - x509.NameAttribute(x509.NameOID.COUNTRY_NAME, attributes.country_name) - ) - if attributes.state_or_province_name: - subject_name_attributes.append( - x509.NameAttribute( - x509.NameOID.STATE_OR_PROVINCE_NAME, - attributes.state_or_province_name, - ) - ) - if attributes.locality_name: - subject_name_attributes.append( - x509.NameAttribute(x509.NameOID.LOCALITY_NAME, attributes.locality_name) - ) - - if subject_name_attributes: - return x509.Name(subject_name_attributes) - - return None - - -def _generate_certificate_request_extensions( - authority_key_identifier: bytes, - csr: x509.CertificateSigningRequest, - is_ca: bool, -) -> List[x509.Extension]: - """Generate a list of certificate extensions from a CSR and other known information. - - Args: - authority_key_identifier (bytes): Authority key identifier - csr (x509.CertificateSigningRequest): CSR - is_ca (bool): Whether the certificate is a CA certificate - - Returns: - List[x509.Extension]: List of extensions - """ - cert_extensions_list: List[x509.Extension] = [ - x509.Extension( - oid=ExtensionOID.AUTHORITY_KEY_IDENTIFIER, - value=x509.AuthorityKeyIdentifier( - key_identifier=authority_key_identifier, - authority_cert_issuer=None, - authority_cert_serial_number=None, - ), - critical=False, - ), - x509.Extension( - oid=ExtensionOID.SUBJECT_KEY_IDENTIFIER, - value=x509.SubjectKeyIdentifier.from_public_key(csr.public_key()), - critical=False, - ), - x509.Extension( - oid=ExtensionOID.BASIC_CONSTRAINTS, - critical=True, - value=x509.BasicConstraints(ca=is_ca, path_length=None), - ), - ] - if sans := _generate_subject_alternative_name_extension(csr): - cert_extensions_list.append(sans) - - if is_ca: - cert_extensions_list.append( - x509.Extension( - ExtensionOID.KEY_USAGE, - critical=True, - value=x509.KeyUsage( - digital_signature=False, - content_commitment=False, - key_encipherment=False, - data_encipherment=False, - key_agreement=False, - key_cert_sign=True, - crl_sign=True, - encipher_only=False, - decipher_only=False, - ), - ) - ) - - existing_oids = {ext.oid for ext in cert_extensions_list} - for extension in csr.extensions: - if extension.oid == ExtensionOID.SUBJECT_ALTERNATIVE_NAME: - continue - if extension.oid in existing_oids: - logger.warning("Extension %s is managed by the TLS provider, ignoring.", extension.oid) - continue - cert_extensions_list.append(extension) - - return cert_extensions_list - - -def _generate_subject_alternative_name_extension( - csr: x509.CertificateSigningRequest, -) -> Optional[x509.Extension]: - sans: List[x509.GeneralName] = [] - try: - loaded_san_ext = csr.extensions.get_extension_for_class(x509.SubjectAlternativeName) - sans.extend( - [x509.DNSName(name) for name in loaded_san_ext.value.get_values_for_type(x509.DNSName)] - ) - sans.extend( - [x509.IPAddress(ip) for ip in loaded_san_ext.value.get_values_for_type(x509.IPAddress)] - ) - sans.extend( - [ - x509.RegisteredID(oid) - for oid in loaded_san_ext.value.get_values_for_type(x509.RegisteredID) - ] - ) - sans.extend( - [ - x509.RFC822Name(name) - for name in loaded_san_ext.value.get_values_for_type(x509.RFC822Name) - ] - ) - except x509.ExtensionNotFound: - pass - # If email is present in the CSR Subject, make sure it is also in the SANS - # to conform to RFC 5280. - email = csr.subject.get_attributes_for_oid(NameOID.EMAIL_ADDRESS) - if email: - email_rfc822 = x509.RFC822Name(str(email[0].value)) - if email_rfc822 not in sans: - sans.append(email_rfc822) - - return ( - x509.Extension( - oid=ExtensionOID.SUBJECT_ALTERNATIVE_NAME, - critical=False, - value=x509.SubjectAlternativeName(sans), - ) - if sans - else None - ) - - -class CertificatesRequirerCharmEvents(CharmEvents): - """List of events that the TLS Certificates requirer charm can leverage.""" - - certificate_available = EventSource(CertificateAvailableEvent) - - -class TLSCertificatesRequiresV4(Object): - """A class to manage the TLS certificates interface for a unit or app.""" - - on = CertificatesRequirerCharmEvents() # type: ignore[reportAssignmentType] - - def __init__( - self, - charm: CharmBase, - relationship_name: str, - certificate_requests: List[CertificateRequestAttributes], - mode: Mode = Mode.UNIT, - refresh_events: List[BoundEvent] = [], - private_key: Optional[PrivateKey] = None, - renewal_relative_time: float = 0.9, - ): - """Create a new instance of the TLSCertificatesRequiresV4 class. - - Args: - charm (CharmBase): The charm instance to relate to. - relationship_name (str): The name of the relation that provides the certificates. - certificate_requests (List[CertificateRequestAttributes]): - A list with the attributes of the certificate requests. - mode (Mode): Whether to use unit or app certificates mode. Default is Mode.UNIT. - In UNIT mode the requirer will place the csr in the unit relation data. - Each unit will manage its private key, - certificate signing request and certificate. - UNIT mode is for use cases where each unit has its own identity. - If you don't know which mode to use, you likely need UNIT. - In APP mode the leader unit will place the csr in the app relation databag. - APP mode is for use cases where the underlying application needs the certificate - for example using it as an intermediate CA to sign other certificates. - The certificate can only be accessed by the leader unit. - refresh_events (List[BoundEvent]): A list of events to trigger a refresh of - the certificates. - private_key (Optional[PrivateKey]): The private key to use for the certificates. - If provided, it will be used instead of generating a new one. - If the key is not valid an exception will be raised. - Using this parameter is discouraged, - having to pass around private keys manually can be a security concern. - Allowing the library to generate and manage the key is the more secure approach. - renewal_relative_time (float): The time to renew the certificate relative to its - expiry. - Default is 0.9, meaning 90% of the validity period. - The minimum value is 0.5, meaning 50% of the validity period. - If an invalid value is provided, an exception will be raised. - """ - super().__init__(charm, relationship_name) - if not JujuVersion.from_environ().has_secrets: - logger.warning("This version of the TLS library requires Juju secrets (Juju >= 3.0)") - if not self._mode_is_valid(mode): - raise TLSCertificatesError("Invalid mode. Must be Mode.UNIT or Mode.APP") - for certificate_request in certificate_requests: - if not certificate_request.is_valid(): - raise TLSCertificatesError("Invalid certificate request") - self.charm = charm - self.relationship_name = relationship_name - self.certificate_requests = certificate_requests - self.mode = mode - if private_key and not private_key.is_valid(): - raise TLSCertificatesError("Invalid private key") - if renewal_relative_time <= 0.5 or renewal_relative_time > 1.0: - raise TLSCertificatesError( - "Invalid renewal relative time. Must be between 0.5 and 1.0" - ) - self._private_key = private_key - self.renewal_relative_time = renewal_relative_time - self.framework.observe(charm.on[relationship_name].relation_created, self._configure) - self.framework.observe(charm.on[relationship_name].relation_changed, self._configure) - self.framework.observe(charm.on.secret_expired, self._on_secret_expired) - self.framework.observe(charm.on.secret_remove, self._on_secret_remove) - for event in refresh_events: - self.framework.observe(event, self._configure) - self._security_logger = _OWASPLogger(application=f"tls-certificates-{charm.app.name}") - - def _configure(self, _: Optional[EventBase] = None): - """Handle TLS Certificates Relation Data. - - This method is called during any TLS relation event. - It will generate a private key if it doesn't exist yet. - It will send certificate requests if they haven't been sent yet. - It will find available certificates and emit events. - """ - if not self._tls_relation_created(): - logger.debug("TLS relation not created yet.") - return - self._ensure_private_key() - self._cleanup_certificate_requests() - self._send_certificate_requests() - self._find_available_certificates() - - def _mode_is_valid(self, mode: Mode) -> bool: - return mode in [Mode.UNIT, Mode.APP] - - def _validate_secret_exists(self, secret: Secret) -> None: - secret.get_info() # Will raise `SecretNotFoundError` if the secret does not exist - - def _on_secret_remove(self, event: SecretRemoveEvent) -> None: - """Handle Secret Removed Event.""" - try: - # Ensure the secret exists before trying to remove it, otherwise - # the unit could be stuck in an error state. See the docstring of - # `remove_revision` and the below issue for more information. - # https://github.com/juju/juju/issues/19036 - self._validate_secret_exists(event.secret) - event.secret.remove_revision(event.revision) - except SecretNotFoundError: - logger.warning( - "No such secret %s, nothing to remove", - event.secret.label or event.secret.id, - ) - return - - def _on_secret_expired(self, event: SecretExpiredEvent) -> None: - """Handle Secret Expired Event. - - Renews certificate requests and removes the expired secret. - """ - if not event.secret.label or not event.secret.label.startswith(f"{LIBID}-certificate"): - return - try: - csr_str = event.secret.get_content(refresh=True)["csr"] - except ModelError: - logger.error("Failed to get CSR from secret - Skipping") - return - csr = CertificateSigningRequest.from_string(csr_str) - self._renew_certificate_request(csr) - event.secret.remove_all_revisions() - - def sync(self) -> None: - """Sync TLS Certificates Relation Data. - - This method allows the requirer to sync the TLS certificates relation data - without waiting for the refresh events to be triggered. - """ - self._configure() - - def renew_certificate(self, certificate: ProviderCertificate) -> None: - """Request the renewal of the provided certificate.""" - certificate_signing_request = certificate.certificate_signing_request - secret_label = self._get_csr_secret_label(certificate_signing_request) - try: - secret = self.model.get_secret(label=secret_label) - except SecretNotFoundError: - logger.warning("No matching secret found - Skipping renewal") - return - current_csr = secret.get_content(refresh=True).get("csr", "") - if current_csr != str(certificate_signing_request): - logger.warning("No matching CSR found - Skipping renewal") - return - self._renew_certificate_request(certificate_signing_request) - secret.remove_all_revisions() - - def _renew_certificate_request(self, csr: CertificateSigningRequest): - """Remove existing CSR from relation data and create a new one.""" - self._remove_requirer_csr_from_relation_data(csr) - self._send_certificate_requests() - logger.info("Renewed certificate request") - - def _remove_requirer_csr_from_relation_data(self, csr: CertificateSigningRequest) -> None: - relation = self.model.get_relation(self.relationship_name) - if not relation: - logger.debug("No relation: %s", self.relationship_name) - return - if not self.get_csrs_from_requirer_relation_data(): - logger.info("No CSRs in relation data - Doing nothing") - return - app_or_unit = self._get_app_or_unit() - try: - requirer_relation_data = _RequirerData.load(relation.data[app_or_unit]) - except DataValidationError: - logger.warning("Invalid relation data - Skipping removal of CSR") - return - new_relation_data = copy.deepcopy(requirer_relation_data.certificate_signing_requests) - for requirer_csr in new_relation_data: - if requirer_csr.certificate_signing_request.strip() == str(csr).strip(): - new_relation_data.remove(requirer_csr) - try: - _RequirerData(certificate_signing_requests=new_relation_data).dump( - relation.data[app_or_unit] - ) - logger.info("Removed CSR from relation data") - except ModelError: - logger.warning("Failed to update relation data") - - def _get_app_or_unit(self) -> Union[Application, Unit]: - """Return the unit or app object based on the mode.""" - if self.mode == Mode.UNIT: - return self.model.unit - elif self.mode == Mode.APP: - return self.model.app - raise TLSCertificatesError("Invalid mode") - - @property - def private_key(self) -> Optional[PrivateKey]: - """Return the private key.""" - if self._private_key: - return self._private_key - if not self._private_key_generated(): - return None - secret = self.charm.model.get_secret(label=self._get_private_key_secret_label()) - private_key = secret.get_content(refresh=True)["private-key"] - return PrivateKey.from_string(private_key) - - def _ensure_private_key(self) -> None: - """Make sure there is a private key to be used. - - It will make sure there is a private key passed by the charm using the private_key - parameter or generate a new one otherwise. - """ - # Remove the generated private key - # if one has been passed by the charm using the private_key parameter - if self._private_key: - self._remove_private_key_secret() - return - if self._private_key_generated(): - logger.debug("Private key already generated") - return - self._generate_private_key() - - def regenerate_private_key(self) -> None: - """Regenerate the private key. - - Generate a new private key, remove old certificate requests and send new ones. - - Raises: - TLSCertificatesError: If the private key is passed by the charm using the - private_key parameter. - """ - if self._private_key: - raise TLSCertificatesError( - "Private key is passed by the charm through the private_key parameter, this function can't be used" - ) - if not self._private_key_generated(): - logger.warning("No private key to regenerate") - return - self._generate_private_key() - self._cleanup_certificate_requests() - self._send_certificate_requests() - - def _generate_private_key(self) -> None: - """Generate a new private key and store it in a secret. - - This is the case when the private key used is generated by the library. - and not passed by the charm using the private_key parameter. - """ - self._store_private_key_in_secret(generate_private_key()) - logger.info("Private key generated") - - def _private_key_generated(self) -> bool: - """Check if a private key is stored in a secret. - - This is the case when the private key used is generated by the library. - This should not exist when the private key used - is passed by the charm using the private_key parameter. - """ - try: - secret = self.charm.model.get_secret(label=self._get_private_key_secret_label()) - secret.get_content(refresh=True) - return True - except SecretNotFoundError: - return False - - def _store_private_key_in_secret(self, private_key: PrivateKey) -> None: - try: - secret = self.charm.model.get_secret(label=self._get_private_key_secret_label()) - secret.set_content({"private-key": str(private_key)}) - secret.get_content(refresh=True) - except SecretNotFoundError: - self.charm.unit.add_secret( - content={"private-key": str(private_key)}, - label=self._get_private_key_secret_label(), - ) - - def _remove_private_key_secret(self) -> None: - """Remove the private key secret.""" - try: - secret = self.charm.model.get_secret(label=self._get_private_key_secret_label()) - secret.remove_all_revisions() - except SecretNotFoundError: - logger.warning("Private key secret not found, nothing to remove") - - def _csr_matches_certificate_request( - self, certificate_signing_request: CertificateSigningRequest, is_ca: bool - ) -> bool: - for certificate_request in self.certificate_requests: - if certificate_request == CertificateRequestAttributes.from_csr( - certificate_signing_request, - is_ca, - ): - return True - return False - - def _certificate_requested(self, certificate_request: CertificateRequestAttributes) -> bool: - if not self.private_key: - return False - csr = self._certificate_requested_for_attributes(certificate_request) - if not csr: - return False - if not csr.certificate_signing_request.matches_private_key(key=self.private_key): - return False - return True - - def _certificate_requested_for_attributes( - self, - certificate_request: CertificateRequestAttributes, - ) -> Optional[RequirerCertificateRequest]: - for requirer_csr in self.get_csrs_from_requirer_relation_data(): - if certificate_request == CertificateRequestAttributes.from_csr( - requirer_csr.certificate_signing_request, - requirer_csr.is_ca, - ): - return requirer_csr - return None - - def get_csrs_from_requirer_relation_data(self) -> List[RequirerCertificateRequest]: - """Return list of requirer's CSRs from relation data.""" - if self.mode == Mode.APP and not self.model.unit.is_leader(): - logger.debug("Not a leader unit - Skipping") - return [] - relation = self.model.get_relation(self.relationship_name) - if not relation: - logger.debug("No relation: %s", self.relationship_name) - return [] - app_or_unit = self._get_app_or_unit() - try: - requirer_relation_data = _RequirerData.load(relation.data[app_or_unit]) - except DataValidationError: - logger.warning("Invalid relation data") - return [] - requirer_csrs = [] - for csr in requirer_relation_data.certificate_signing_requests: - requirer_csrs.append( - RequirerCertificateRequest( - relation_id=relation.id, - certificate_signing_request=CertificateSigningRequest.from_string( - csr.certificate_signing_request - ), - is_ca=csr.ca if csr.ca else False, - ) - ) - return requirer_csrs - - def get_provider_certificates(self) -> List[ProviderCertificate]: - """Return list of certificates from the provider's relation data.""" - return self._load_provider_certificates() - - def _load_provider_certificates(self) -> List[ProviderCertificate]: - relation = self.model.get_relation(self.relationship_name) - if not relation: - logger.debug("No relation: %s", self.relationship_name) - return [] - if not relation.app: - logger.debug("No remote app in relation: %s", self.relationship_name) - return [] - try: - provider_relation_data = _ProviderApplicationData.load(relation.data[relation.app]) - except DataValidationError: - logger.warning("Invalid relation data") - return [] - return [ - certificate.to_provider_certificate(relation_id=relation.id) - for certificate in provider_relation_data.certificates - ] - - def _request_certificate(self, csr: CertificateSigningRequest, is_ca: bool) -> None: - """Add CSR to relation data.""" - if self.mode == Mode.APP and not self.model.unit.is_leader(): - logger.debug("Not a leader unit - Skipping") - return - relation = self.model.get_relation(self.relationship_name) - if not relation: - logger.debug("No relation: %s", self.relationship_name) - return - new_csr = _CertificateSigningRequest( - certificate_signing_request=str(csr).strip(), ca=is_ca - ) - app_or_unit = self._get_app_or_unit() - try: - requirer_relation_data = _RequirerData.load(relation.data[app_or_unit]) - except DataValidationError: - requirer_relation_data = _RequirerData( - certificate_signing_requests=[], - ) - new_relation_data = copy.deepcopy(requirer_relation_data.certificate_signing_requests) - new_relation_data.append(new_csr) - try: - _RequirerData(certificate_signing_requests=new_relation_data).dump( - relation.data[app_or_unit] - ) - logger.info("Certificate signing request added to relation data.") - except ModelError: - logger.warning("Failed to update relation data") - - def _send_certificate_requests(self): - if not self.private_key: - logger.debug("Private key not generated yet.") - return - for certificate_request in self.certificate_requests: - if not self._certificate_requested(certificate_request): - csr = certificate_request.generate_csr( - private_key=self.private_key, - ) - if not csr: - logger.warning("Failed to generate CSR") - continue - self._request_certificate(csr=csr, is_ca=certificate_request.is_ca) - - def get_assigned_certificate( - self, certificate_request: CertificateRequestAttributes - ) -> Tuple[Optional[ProviderCertificate], Optional[PrivateKey]]: - """Get the certificate that was assigned to the given certificate request.""" - for requirer_csr in self.get_csrs_from_requirer_relation_data(): - if certificate_request == CertificateRequestAttributes.from_csr( - requirer_csr.certificate_signing_request, - requirer_csr.is_ca, - ): - return self._find_certificate_in_relation_data(requirer_csr), self.private_key - return None, None - - def get_assigned_certificates( - self, - ) -> Tuple[List[ProviderCertificate], Optional[PrivateKey]]: - """Get a list of certificates that were assigned to this or app.""" - assigned_certificates = [] - for requirer_csr in self.get_csrs_from_requirer_relation_data(): - if cert := self._find_certificate_in_relation_data(requirer_csr): - assigned_certificates.append(cert) - return assigned_certificates, self.private_key - - def _find_certificate_in_relation_data( - self, csr: RequirerCertificateRequest - ) -> Optional[ProviderCertificate]: - """Return the certificate that matches the given CSR, validated against the private key.""" - if not self.private_key: - return None - for provider_certificate in self.get_provider_certificates(): - if provider_certificate.certificate_signing_request == csr.certificate_signing_request: - if provider_certificate.certificate.is_ca and not csr.is_ca: - logger.warning("Non CA certificate requested, got a CA certificate, ignoring") - continue - elif not provider_certificate.certificate.is_ca and csr.is_ca: - logger.warning("CA certificate requested, got a non CA certificate, ignoring") - continue - if not provider_certificate.certificate.matches_private_key(self.private_key): - logger.warning( - "Certificate does not match the private key. Ignoring invalid certificate." - ) - continue - return provider_certificate - return None - - def _find_available_certificates(self): - """Find available certificates and emit events. - - This method will find certificates that are available for the requirer's CSRs. - If a certificate is found, it will be set as a secret and an event will be emitted. - If a certificate is revoked, the secret will be removed and an event will be emitted. - """ - requirer_csrs = self.get_csrs_from_requirer_relation_data() - csrs = [csr.certificate_signing_request for csr in requirer_csrs] - provider_certificates = self.get_provider_certificates() - for provider_certificate in provider_certificates: - if provider_certificate.certificate_signing_request in csrs: - secret_label = self._get_csr_secret_label( - provider_certificate.certificate_signing_request - ) - if provider_certificate.revoked: - with suppress(SecretNotFoundError): - logger.debug( - "Removing secret with label %s", - secret_label, - ) - secret = self.model.get_secret(label=secret_label) - secret.remove_all_revisions() - else: - if not self._csr_matches_certificate_request( - certificate_signing_request=provider_certificate.certificate_signing_request, - is_ca=provider_certificate.certificate.is_ca, - ): - logger.debug("Certificate requested for different attributes - Skipping") - continue - try: - secret = self.model.get_secret(label=secret_label) - logger.debug("Setting secret with label %s", secret_label) - # Juju < 3.6 will create a new revision even if the content is the same - if secret.get_content(refresh=True).get("certificate", "") == str( - provider_certificate.certificate - ): - logger.debug( - "Secret %s with correct certificate already exists", secret_label - ) - continue - secret.set_content( - content={ - "certificate": str(provider_certificate.certificate), - "csr": str(provider_certificate.certificate_signing_request), - } - ) - secret.set_info( - expire=calculate_relative_datetime( - target_time=provider_certificate.certificate.expiry_time, - fraction=self.renewal_relative_time, - ), - ) - secret.get_content(refresh=True) - except SecretNotFoundError: - logger.debug("Creating new secret with label %s", secret_label) - secret = self.charm.unit.add_secret( - content={ - "certificate": str(provider_certificate.certificate), - "csr": str(provider_certificate.certificate_signing_request), - }, - label=secret_label, - expire=calculate_relative_datetime( - target_time=provider_certificate.certificate.expiry_time, - fraction=self.renewal_relative_time, - ), - ) - self.on.certificate_available.emit( - certificate_signing_request=provider_certificate.certificate_signing_request, - certificate=provider_certificate.certificate, - ca=provider_certificate.ca, - chain=provider_certificate.chain, - ) - - def _cleanup_certificate_requests(self): - """Clean up certificate requests. - - Remove any certificate requests that falls into one of the following categories: - - The CSR attributes do not match any of the certificate requests defined in - the charm's certificate_requests attribute. - - The CSR public key does not match the private key. - """ - for requirer_csr in self.get_csrs_from_requirer_relation_data(): - if not self._csr_matches_certificate_request( - certificate_signing_request=requirer_csr.certificate_signing_request, - is_ca=requirer_csr.is_ca, - ): - self._remove_requirer_csr_from_relation_data( - requirer_csr.certificate_signing_request - ) - logger.info( - "Removed CSR from relation data because it did not match any certificate request" # noqa: E501 - ) - elif ( - self.private_key - and not requirer_csr.certificate_signing_request.matches_private_key( - self.private_key - ) - ): - self._remove_requirer_csr_from_relation_data( - requirer_csr.certificate_signing_request - ) - logger.info( - "Removed CSR from relation data because it did not match the private key" - ) # noqa: E501 - - def _tls_relation_created(self) -> bool: - relation = self.model.get_relation(self.relationship_name) - if not relation: - return False - return True - - def _get_private_key_secret_label(self) -> str: - if self.mode == Mode.UNIT: - return f"{LIBID}-private-key-{self._get_unit_number()}-{self.relationship_name}" - elif self.mode == Mode.APP: - return f"{LIBID}-private-key-{self.relationship_name}" - else: - raise TLSCertificatesError("Invalid mode. Must be Mode.UNIT or Mode.APP.") - - def _get_csr_secret_label(self, csr: CertificateSigningRequest) -> str: - csr_in_sha256_hex = csr.get_sha256_hex() - if self.mode == Mode.UNIT: - return f"{LIBID}-certificate-{self._get_unit_number()}-{csr_in_sha256_hex}" - elif self.mode == Mode.APP: - return f"{LIBID}-certificate-{csr_in_sha256_hex}" - else: - raise TLSCertificatesError("Invalid mode. Must be Mode.UNIT or Mode.APP.") - - def _get_unit_number(self) -> str: - return self.model.unit.name.split("/")[1] - - -class TLSCertificatesProvidesV4(Object): - """TLS certificates provider class to be instantiated by TLS certificates providers.""" - - def __init__(self, charm: CharmBase, relationship_name: str): - super().__init__(charm, relationship_name) - self.framework.observe(charm.on[relationship_name].relation_joined, self._configure) - self.framework.observe(charm.on[relationship_name].relation_changed, self._configure) - self.framework.observe(charm.on.update_status, self._configure) - self.charm = charm - self.relationship_name = relationship_name - self._security_logger = _OWASPLogger(application=f"tls-certificates-{charm.app.name}") - - def _configure(self, _: EventBase) -> None: - """Handle update status and tls relation changed events. - - This is a common hook triggered on a regular basis. - - Revoke certificates for which no csr exists - """ - if not self.model.unit.is_leader(): - return - self._remove_certificates_for_which_no_csr_exists() - - def _remove_certificates_for_which_no_csr_exists(self) -> None: - provider_certificates = self.get_provider_certificates() - requirer_csrs = [ - request.certificate_signing_request for request in self.get_certificate_requests() - ] - for provider_certificate in provider_certificates: - if provider_certificate.certificate_signing_request not in requirer_csrs: - tls_relation = self._get_tls_relations( - relation_id=provider_certificate.relation_id - ) - self._remove_provider_certificate( - certificate=provider_certificate.certificate, - relation=tls_relation[0], - ) - - def _get_tls_relations(self, relation_id: Optional[int] = None) -> List[Relation]: - return ( - [ - relation - for relation in self.model.relations[self.relationship_name] - if relation.id == relation_id - ] - if relation_id is not None - else self.model.relations.get(self.relationship_name, []) - ) - - def get_certificate_requests( - self, relation_id: Optional[int] = None - ) -> List[RequirerCertificateRequest]: - """Load certificate requests from the relation data.""" - relations = self._get_tls_relations(relation_id) - requirer_csrs: List[RequirerCertificateRequest] = [] - for relation in relations: - for unit in relation.units: - requirer_csrs.extend(self._load_requirer_databag(relation, unit)) - requirer_csrs.extend(self._load_requirer_databag(relation, relation.app)) - return requirer_csrs - - def _load_requirer_databag( - self, relation: Relation, unit_or_app: Union[Application, Unit] - ) -> List[RequirerCertificateRequest]: - try: - requirer_relation_data = _RequirerData.load(relation.data.get(unit_or_app, {})) - except DataValidationError: - logger.debug("Invalid requirer relation data for %s", unit_or_app.name) - return [] - return [ - RequirerCertificateRequest( - relation_id=relation.id, - certificate_signing_request=CertificateSigningRequest.from_string( - csr.certificate_signing_request - ), - is_ca=csr.ca if csr.ca else False, - ) - for csr in requirer_relation_data.certificate_signing_requests - ] - - def _add_provider_certificate( - self, - relation: Relation, - provider_certificate: ProviderCertificate, - ) -> None: - chain = [str(certificate) for certificate in provider_certificate.chain] - if chain[0] != str(provider_certificate.certificate): - logger.warning( - "The order of the chain from the TLS Certificates Provider is incorrect. " - "The leaf certificate should be the first element of the chain." - ) - elif not chain_has_valid_order(chain): - logger.warning( - "The order of the chain from the TLS Certificates Provider is partially incorrect." - ) - new_certificate = _Certificate( - certificate=str(provider_certificate.certificate), - certificate_signing_request=str(provider_certificate.certificate_signing_request), - ca=str(provider_certificate.ca), - chain=chain, - ) - provider_certificates = self._load_provider_certificates(relation) - if new_certificate in provider_certificates: - logger.info("Certificate already in relation data - Doing nothing") - return - provider_certificates.append(new_certificate) - self._dump_provider_certificates(relation=relation, certificates=provider_certificates) - - def _load_provider_certificates(self, relation: Relation) -> List[_Certificate]: - try: - provider_relation_data = _ProviderApplicationData.load(relation.data[self.charm.app]) - except DataValidationError: - logger.debug("Invalid provider relation data") - return [] - return copy.deepcopy(provider_relation_data.certificates) - - def _dump_provider_certificates(self, relation: Relation, certificates: List[_Certificate]): - try: - _ProviderApplicationData(certificates=certificates).dump(relation.data[self.model.app]) - logger.info("Certificate relation data updated") - except ModelError: - logger.warning("Failed to update relation data") - - def _remove_provider_certificate( - self, - relation: Relation, - certificate: Optional[Certificate] = None, - certificate_signing_request: Optional[CertificateSigningRequest] = None, - ) -> None: - """Remove certificate based on certificate or certificate signing request.""" - provider_certificates = self._load_provider_certificates(relation) - for provider_certificate in provider_certificates: - if certificate and provider_certificate.certificate == str(certificate): - provider_certificates.remove(provider_certificate) - if ( - certificate_signing_request - and provider_certificate.certificate_signing_request - == str(certificate_signing_request) - ): - provider_certificates.remove(provider_certificate) - self._dump_provider_certificates(relation=relation, certificates=provider_certificates) - - def revoke_all_certificates(self) -> None: - """Revoke all certificates of this provider. - - This method is meant to be used when the Root CA has changed. - """ - if not self.model.unit.is_leader(): - logger.warning("Unit is not a leader - will not set relation data") - return - relations = self._get_tls_relations() - for relation in relations: - provider_certificates = self._load_provider_certificates(relation) - for certificate in provider_certificates: - certificate.revoked = True - self._dump_provider_certificates(relation=relation, certificates=provider_certificates) - self._security_logger.log_event( - event="all_certificates_revoked", - level=logging.WARNING, - description="All certificates revoked", - ) - - def set_relation_certificate( - self, - provider_certificate: ProviderCertificate, - ) -> None: - """Add certificates to relation data. - - Args: - provider_certificate (ProviderCertificate): ProviderCertificate object - - Returns: - None - """ - if not self.model.unit.is_leader(): - logger.warning("Unit is not a leader - will not set relation data") - return - certificates_relation = self.model.get_relation( - relation_name=self.relationship_name, relation_id=provider_certificate.relation_id - ) - if not certificates_relation: - raise TLSCertificatesError(f"Relation {self.relationship_name} does not exist") - self._remove_provider_certificate( - relation=certificates_relation, - certificate_signing_request=provider_certificate.certificate_signing_request, - ) - self._add_provider_certificate( - relation=certificates_relation, - provider_certificate=provider_certificate, - ) - self._security_logger.log_event( - event="certificate_provided", - level=logging.INFO, - description="Certificate provided to requirer", - relation_id=str(provider_certificate.relation_id), - common_name=provider_certificate.certificate.common_name, - ) - - def get_issued_certificates( - self, relation_id: Optional[int] = None - ) -> List[ProviderCertificate]: - """Return a List of issued (non revoked) certificates. - - Returns: - List: List of ProviderCertificate objects - """ - if not self.model.unit.is_leader(): - logger.warning("Unit is not a leader - will not read relation data") - return [] - provider_certificates = self.get_provider_certificates(relation_id=relation_id) - return [certificate for certificate in provider_certificates if not certificate.revoked] - - def get_provider_certificates( - self, relation_id: Optional[int] = None - ) -> List[ProviderCertificate]: - """Return a List of issued certificates.""" - certificates: List[ProviderCertificate] = [] - relations = self._get_tls_relations(relation_id) - for relation in relations: - if not relation.app: - logger.warning("Relation %s does not have an application", relation.id) - continue - for certificate in self._load_provider_certificates(relation): - certificates.append(certificate.to_provider_certificate(relation_id=relation.id)) - return certificates - - def get_unsolicited_certificates( - self, relation_id: Optional[int] = None - ) -> List[ProviderCertificate]: - """Return provider certificates for which no certificate requests exists. - - Those certificates should be revoked. - """ - unsolicited_certificates: List[ProviderCertificate] = [] - provider_certificates = self.get_provider_certificates(relation_id=relation_id) - requirer_csrs = self.get_certificate_requests(relation_id=relation_id) - list_of_csrs = [csr.certificate_signing_request for csr in requirer_csrs] - for certificate in provider_certificates: - if certificate.certificate_signing_request not in list_of_csrs: - unsolicited_certificates.append(certificate) - return unsolicited_certificates - - def get_outstanding_certificate_requests( - self, relation_id: Optional[int] = None - ) -> List[RequirerCertificateRequest]: - """Return CSR's for which no certificate has been issued. - - Args: - relation_id (int): Relation id - - Returns: - list: List of RequirerCertificateRequest objects. - """ - requirer_csrs = self.get_certificate_requests(relation_id=relation_id) - outstanding_csrs: List[RequirerCertificateRequest] = [] - for relation_csr in requirer_csrs: - if not self._certificate_issued_for_csr( - csr=relation_csr.certificate_signing_request, - relation_id=relation_id, - ): - outstanding_csrs.append(relation_csr) - return outstanding_csrs - - def _certificate_issued_for_csr( - self, csr: CertificateSigningRequest, relation_id: Optional[int] - ) -> bool: - """Check whether a certificate has been issued for a given CSR.""" - issued_certificates_per_csr = self.get_issued_certificates(relation_id=relation_id) - for issued_certificate in issued_certificates_per_csr: - if issued_certificate.certificate_signing_request == csr: - return csr.matches_certificate(issued_certificate.certificate) - return False diff --git a/dovecot-charm/uv.lock b/dovecot-charm/uv.lock index f785fe8..e7131a8 100644 --- a/dovecot-charm/uv.lock +++ b/dovecot-charm/uv.lock @@ -160,6 +160,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/dd/92/4db19cd8bc94db51a115f7a2e3c46d96b991ca7ebe27207beac9a6570bc6/charmlibs_apt-1.0.0.post0-py3-none-any.whl", hash = "sha256:958e84719eb1feff539f058dc6c7af648c53c88b9ebe7c6157ec8d2bdf5fbfc6", size = 19287, upload-time = "2025-10-15T02:40:27.756Z" }, ] +[[package]] +name = "charmlibs-interfaces-tls-certificates" +version = "1.8.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cryptography" }, + { name = "ops" }, + { name = "pydantic" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/78/7e/166af1e71f2bf96482845a1806dc345cbc5507134a99ccbbae297f174e4b/charmlibs_interfaces_tls_certificates-1.8.1.tar.gz", hash = "sha256:f2bfabf3a3b4c18034941771733177b30e4742c06d7742d4bb30da6ead953f43", size = 148059, upload-time = "2026-02-27T13:46:50.086Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/95/17/1d1b0083800f4cc20f42e5d2763521d93975376499565c62da5276a80629/charmlibs_interfaces_tls_certificates-1.8.1-py3-none-any.whl", hash = "sha256:8e8fe047e02515d76f57a1d019056d72ce8c859c2ffb39a1e379cfc11fc048e6", size = 28208, upload-time = "2026-02-27T13:46:48.959Z" }, +] + [[package]] name = "charmlibs-systemd" version = "1.0.0" @@ -413,6 +427,7 @@ source = { virtual = "." } dependencies = [ { name = "charmhelpers" }, { name = "charmlibs-apt" }, + { name = "charmlibs-interfaces-tls-certificates" }, { name = "charmlibs-systemd" }, { name = "cryptography" }, { name = "jinja2" }, @@ -458,6 +473,7 @@ unit = [ requires-dist = [ { name = "charmhelpers", specifier = ">=1.2.1" }, { name = "charmlibs-apt", specifier = "==1.0.0.post0" }, + { name = "charmlibs-interfaces-tls-certificates", specifier = ">=1.8.1" }, { name = "charmlibs-systemd", specifier = "==1.0.0" }, { name = "cryptography", specifier = ">=46.0.6" }, { name = "jinja2" }, From 7c634f8985854c04a0bf810ac558e1137f39ad91 Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Mon, 13 Apr 2026 15:12:42 +0300 Subject: [PATCH 04/39] fix: vale --- docs/release-notes/artifacts/pr-3-tls.yaml | 7 +++++-- docs/release-notes/release-notes-0004.rst | 4 ++-- dovecot-charm/src/charm.py | 2 +- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/docs/release-notes/artifacts/pr-3-tls.yaml b/docs/release-notes/artifacts/pr-3-tls.yaml index 83efdba..93cf8a7 100644 --- a/docs/release-notes/artifacts/pr-3-tls.yaml +++ b/docs/release-notes/artifacts/pr-3-tls.yaml @@ -1,11 +1,14 @@ +# Copyright 2026 Canonical Ltd. +# See LICENSE file for licensing details. + # Version of the artifact schema version_schema: 2 changes: -- title: Added TLS certificate integration via the certificates relation +- title: Added TLS certificate integration using the certificates relation author: alithethird type: major - description: Added TLS support using the tls-certificates-interface library. The charm requests certificates via the certificates relation, writes the cert and key to /etc/dovecot/private/, and restarts Dovecot automatically after installation. + description: Added TLS support using the tls-certificates-interface library. The charm requests certificates using the certificates relation, writes the cert and key to /etc/dovecot/private/, and restarts Dovecot automatically after installation. urls: pr: - "https://github.com/canonical/mailserver-operators/pull/4" diff --git a/docs/release-notes/release-notes-0004.rst b/docs/release-notes/release-notes-0004.rst index ab54e1e..ad62d48 100644 --- a/docs/release-notes/release-notes-0004.rst +++ b/docs/release-notes/release-notes-0004.rst @@ -7,7 +7,7 @@ These release notes cover new features and changes in Dovecot. Main features: -* Added TLS certificate integration via the ``certificates`` relation. +* Added TLS certificate integration using the ``certificates`` relation. See our :ref:`Release policy and schedule `. @@ -35,7 +35,7 @@ The following major and minor features were added in this release. Added TLS certificate integration via the ``certificates`` relation ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Added TLS support to the Dovecot charm using the ``tls-certificates-interface`` library. When a ``certificates`` relation is established, the charm requests a certificate for the configured mailname and handles the ``certificate_available`` event by writing the certificate and private key to ``/etc/dovecot/private/``. The Dovecot service is automatically restarted after certificate installation. The Dovecot configuration template was updated to reference the certificate and key paths for IMAPS and POP3S listeners. +Added TLS support to the Dovecot charm using the ``tls-certificates-interface`` library. When a ``certificates`` relation is established, the charm requests a certificate for the configured mail name and handles the ``certificate_available`` event by writing the certificate and private key to ``/etc/dovecot/private/``. The Dovecot service is automatically restarted after certificate installation. The Dovecot configuration template was updated to reference the certificate and key paths for IMAPS and POP3S listeners. Relevant links: diff --git a/dovecot-charm/src/charm.py b/dovecot-charm/src/charm.py index 9588591..2c9900b 100644 --- a/dovecot-charm/src/charm.py +++ b/dovecot-charm/src/charm.py @@ -21,7 +21,7 @@ ) from ops.charm import CharmBase from ops.main import main -from ops.model import ActiveStatus, BlockedStatus, MaintenanceStatus +from ops.model import BlockedStatus, MaintenanceStatus from constants import ( DOVECOT_CONF_TARGET, From 2d793f29b314ac9071aa6732d5889e970b8123f8 Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Mon, 13 Apr 2026 15:15:18 +0300 Subject: [PATCH 05/39] fix: vale --- docs/release-notes/release-notes-0004.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/release-notes/release-notes-0004.rst b/docs/release-notes/release-notes-0004.rst index ad62d48..4578801 100644 --- a/docs/release-notes/release-notes-0004.rst +++ b/docs/release-notes/release-notes-0004.rst @@ -32,8 +32,8 @@ Updates The following major and minor features were added in this release. -Added TLS certificate integration via the ``certificates`` relation -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Added TLS certificate integration using the ``certificates`` relation +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Added TLS support to the Dovecot charm using the ``tls-certificates-interface`` library. When a ``certificates`` relation is established, the charm requests a certificate for the configured mail name and handles the ``certificate_available`` event by writing the certificate and private key to ``/etc/dovecot/private/``. The Dovecot service is automatically restarted after certificate installation. The Dovecot configuration template was updated to reference the certificate and key paths for IMAPS and POP3S listeners. From 0c52c1838bd22f6633496a67951f11ab3f53485b Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Tue, 14 Apr 2026 10:43:27 +0300 Subject: [PATCH 06/39] fix(tls): make ssl=required conditional on cert file presence When no certificates relation is active the TLS cert files do not exist yet. Unconditionally setting ssl=required in the dovecot config caused dovecot to fail to start, putting the unit in error on every upgrade-charm or install event before a cert is provisioned. Template now checks tls_enabled (passed from charm, true iff the .pem file exists in /etc/dovecot/private/) and falls back to ssl=yes when no cert is present. ssl=required is restored automatically on the next _reconcile after certificate_available writes the cert to disk. Also register --use-existing pytest option in tests/conftest.py so it can be passed on the command line without an 'unrecognised arguments' error. --- dovecot-charm/src/charm.py | 1 + dovecot-charm/templates/dovecot.conf.tmpl | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/dovecot-charm/src/charm.py b/dovecot-charm/src/charm.py index 2c9900b..6fd84b6 100644 --- a/dovecot-charm/src/charm.py +++ b/dovecot-charm/src/charm.py @@ -190,6 +190,7 @@ def _setup_dovecot(self, dovecot_config: DovecotConfig) -> None: "mail_root": MAIL_ROOT, "mailname": dovecot_config.mailname, "postmaster_address": dovecot_config.postmaster_address, + "tls_enabled": (self.tls_cert_dir / f"{dovecot_config.mailname}.pem").exists(), } template = self.jinja.get_template(DOVECOT_CONF_TEMPLATE) contents = template.render(template_context) diff --git a/dovecot-charm/templates/dovecot.conf.tmpl b/dovecot-charm/templates/dovecot.conf.tmpl index 5dc5493..89a4f89 100644 --- a/dovecot-charm/templates/dovecot.conf.tmpl +++ b/dovecot-charm/templates/dovecot.conf.tmpl @@ -10,9 +10,13 @@ auth_verbose = yes auth_verbose_passwords = no # TODO: change to ssl = required once TLS relation is added (pr/5-tls) +{% if tls_enabled %} ssl = required ssl_cert = Date: Tue, 14 Apr 2026 12:29:54 +0300 Subject: [PATCH 07/39] fix(tests): ensure mountpoint check executes correctly in data persistence test --- dovecot-charm/tests/unit/test_charm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dovecot-charm/tests/unit/test_charm.py b/dovecot-charm/tests/unit/test_charm.py index 5cb19bd..1871251 100644 --- a/dovecot-charm/tests/unit/test_charm.py +++ b/dovecot-charm/tests/unit/test_charm.py @@ -1,5 +1,6 @@ # Copyright 2026 Canonical Ltd. # See LICENSE file for licensing details. +import dataclasses from subprocess import CalledProcessError # nosec from unittest.mock import MagicMock, patch From b9d86c7a8ad5e9d836dedbcd159a21a94bf87ef1 Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Tue, 14 Apr 2026 14:50:27 +0300 Subject: [PATCH 08/39] fix(tests): update LUKS secret generation in dovecot charm tests --- docs/explanation/charm-state-diagrams.md | 238 ++++++++++++++++++++ dovecot-charm/tests/integration/conftest.py | 4 +- 2 files changed, 241 insertions(+), 1 deletion(-) create mode 100644 docs/explanation/charm-state-diagrams.md diff --git a/docs/explanation/charm-state-diagrams.md b/docs/explanation/charm-state-diagrams.md new file mode 100644 index 0000000..400437e --- /dev/null +++ b/docs/explanation/charm-state-diagrams.md @@ -0,0 +1,238 @@ +# Dovecot Charm State Diagrams + +Based on `origin/pr/2-storage` + exception-based refactor. + +--- + +## Diagram 1 — Event → Handler → Unit Status + +Shows which Juju events trigger which handlers, and every possible `unit.status` outcome. +Actions and `replicas.relation_created` produce no status changes. + +```mermaid +flowchart TD + %% ── Juju events ────────────────────────────────────────────────────────── + EV_INSTALL([install]) + EV_CONFIG([config_changed]) + EV_UPGRADE([upgrade_charm]) + EV_STOR_ATT([mail_data_storage_attached]) + EV_STOR_DET([mail_data_storage_detaching]) + EV_PEER([replicas.relation_created]) + EV_ACTION([clear_queue action]) + + %% ── Handlers ───────────────────────────────────────────────────────────── + H_INSTALL[_on_install] + H_RECONCILE[_reconcile] + H_PEER[_on_peer_relation_created\nwrites unit-name to relation data] + H_ACTION[_on_clear_queue_action\nno status change] + + %% ── Status outcomes ────────────────────────────────────────────────────── + M_INSTALLING(["● Maintenance\nInstalling packages"]) + M_DEPS(["● Maintenance\nInstalling required dependencies"]) + M_DONE(["● Maintenance\nCharm installation done"]) + M_CONFIGURING(["● Maintenance\nConfiguring charm"]) + M_DOVECOT(["● Maintenance\nSetting up and configuring dovecot"]) + M_DOVECOT_OK(["● Maintenance\nDovecot configuration updated"]) + M_PROCMAIL(["● Maintenance\nSetting up and configuring procmail"]) + + B_CONFIG(["✖ Blocked\nInvalid charm configuration\n(mailname / postmaster-address /\nprimary-unit / luks-key)\nraised: ConfigurationError"]) + B_LUKS_DISABLED(["✖ Blocked\nmail-data not mounted;\nmanage-luks disabled\nraised: StorageError"]) + B_LUKS_FAILED(["✖ Blocked\nFailed to setup LUKS storage\nraised: StorageError"]) + B_LUKS_RT(["✖ Blocked\n\n(device missing / not block /\nluksFormat / open /\ndmsetup / mkfs / mount)\nraised: StorageError"]) + B_DOVECONF(["✖ Blocked\nInvalid Dovecot configuration\nraised: ConfigurationError"]) + B_POSTFIX(["✖ Blocked\nFailed to configure postfix:\n\nraised: ConfigurationError"]) + + ACTIVE(["✔ Active"]) + + SILENT["(no status change)\ndoveconf not yet installed\n— logs warning, returns"] + + %% ── Event wiring ───────────────────────────────────────────────────────── + EV_INSTALL --> H_INSTALL + EV_CONFIG --> H_RECONCILE + EV_UPGRADE --> H_RECONCILE + EV_STOR_ATT --> H_RECONCILE + EV_STOR_DET --> H_RECONCILE + EV_PEER --> H_PEER + EV_ACTION --> H_ACTION + + %% ── _on_install flow ───────────────────────────────────────────────────── + H_INSTALL --> M_INSTALLING + M_INSTALLING --> M_DEPS + M_DEPS --> M_DONE + M_DONE -->|"calls _reconcile"| H_RECONCILE + + %% ── _reconcile: storage+config try/except block ────────────────────────── + H_RECONCILE --> M_CONFIGURING + M_CONFIGURING -->|"ConfigurationError\n(_get_dovecot_config)"| B_CONFIG + M_CONFIGURING -->|"StorageError: not mounted\n(ensure_storage_ready)"| B_LUKS_DISABLED + M_CONFIGURING -->|"StorageError: CalledProcessError\n(ensure_storage_ready)"| B_LUKS_FAILED + M_CONFIGURING -->|"StorageError: RuntimeError\n(ensure_storage_ready)"| B_LUKS_RT + M_CONFIGURING -->|"shutil.which('doveconf') is None"| SILENT + M_CONFIGURING -->|"all pass → _setup_dovecot"| M_DOVECOT + + %% ── _reconcile: dovecot+procmail try/except block ──────────────────────── + M_DOVECOT -->|"ConfigurationError\n(doveconf -c fails)"| B_DOVECONF + M_DOVECOT -->|"validation OK\n→ service_reload(dovecot)"| M_DOVECOT_OK + M_DOVECOT_OK --> M_PROCMAIL + M_PROCMAIL -->|"ConfigurationError\n(postconf -e fails)"| B_POSTFIX + M_PROCMAIL -->|"service_reload(postfix) OK\n→ open_ports()"| ACTIVE + + %% ── Styles ─────────────────────────────────────────────────────────────── + classDef event fill:#dbeafe,stroke:#3b82f6,color:#1e3a5f + classDef handler fill:#f3f4f6,stroke:#6b7280,color:#111827 + classDef maint fill:#fef9c3,stroke:#ca8a04,color:#713f12 + classDef blocked fill:#fee2e2,stroke:#dc2626,color:#7f1d1d + classDef active fill:#dcfce7,stroke:#16a34a,color:#14532d + classDef silent fill:#f3f4f6,stroke:#9ca3af,color:#6b7280,stroke-dasharray:4 4 + + class EV_INSTALL,EV_CONFIG,EV_UPGRADE,EV_STOR_ATT,EV_STOR_DET,EV_PEER,EV_ACTION event + class H_INSTALL,H_RECONCILE,H_PEER,H_ACTION handler + class M_INSTALLING,M_DEPS,M_DONE,M_CONFIGURING,M_DOVECOT,M_DOVECOT_OK,M_PROCMAIL maint + class B_CONFIG,B_LUKS_DISABLED,B_LUKS_FAILED,B_LUKS_RT,B_DOVECONF,B_POSTFIX blocked + class ACTIVE active + class SILENT silent +``` + +--- + +## Diagram 2 — `_reconcile` Internal Call Chain + +Full execution path inside `_reconcile`, showing both `try/except` blocks and every branch. + +```mermaid +flowchart TD + START(["_reconcile(event) called\nconfig_changed / upgrade_charm /\nmail_data_storage_attached /\nmail_data_storage_detaching /\n[via _on_install]"]) + + S1["unit.status =\nMaintenance('Configuring charm')"] + + %% ── try block 1: config + storage ─────────────────────────────────────── + TRY1[/"try"/] + + S2["_get_dovecot_config()\nDovecotConfig.from_charm()"] + S2_RAISES["raises ConfigurationError\n'Invalid charm configuration…'\n(mailname / postmaster-address /\nprimary-unit / luks-key)"] + + S3["ensure_storage_ready(charm)\nstorage.py"] + + S3A{"manage_luks = False"} + S3A_MT{"_mail_storage_mounted()\nos.path.ismount('/srv/mail')"} + S3A_RAISE["raises StorageError\n'mail-data not mounted;\nmanage-luks disabled'"] + S3A_OK["return (proceed)"] + + S3B{"manage_luks = True\nshutil.which('cryptsetup')"} + S3B_NONE["None → log warning\nreturn (defer silently)"] + + S3C{"storages / dev_path\nvalid?"} + S3C_BAD["empty or None\nlog error\nreturn (no block)"] + + S3D["setup_luks_storage(luks_key, dev_path)"] + S3D_STEPS["① isLuks check\n② luksFormat if new (key via stdin)\n③ cryptsetup open if not mapped\n④ dmsetup mknodes\n⑤ blkid check for ext4\n⑥ mkfs.ext4 if no fs\n⑦ configure_file /etc/fstab\n⑧ mount → /srv/mail"] + S3D_CPE["CalledProcessError\nraises StorageError\n'Failed to setup LUKS storage'"] + S3D_RTE["RuntimeError\nraises StorageError(str(e))"] + S3D_OK["return (LUKS ready)"] + + S3E["teardown_detaching_storage(charm)"] + S3E_STEPS["if storages present → return (no-op)\nif storages gone:\n manage_luks + mounted → umount\n mapper exists → luksClose\nCalledProcessError → log only"] + + CATCH1["except CharmBlockedError as e\nunit.status = Blocked(str(e))\nreturn"] + + %% ── doveconf guard ─────────────────────────────────────────────────────── + S4{"shutil.which('doveconf')"} + S4_NONE["log warning\n'Dovecot not installed yet'\nreturn\n(stays in Maintenance\n'Configuring charm')"] + + %% ── try block 2: dovecot + procmail ───────────────────────────────────── + TRY2[/"try"/] + + S5A["_setup_dovecot(dovecot_config)"] + S5A_1["unit.status =\nMaintenance('Setting up and\nconfiguring dovecot')"] + S5A_2["render dovecot.conf.tmpl\nwrite → /etc/dovecot/conf.d/\n99-local-dovecot-charm.conf"] + S5A_3{"doveconf -c\n/etc/dovecot/conf.d/\n99-local-dovecot-charm.conf"} + S5A_RAISE["raises ConfigurationError\n'Invalid Dovecot configuration,\ncheck logs for details'"] + S5A_OK["service_reload('dovecot',\nrestart_on_failure=True)\nunit.status =\nMaintenance('Dovecot\nconfiguration updated')"] + + S5B["_setup_procmail()"] + S5B_1["unit.status =\nMaintenance('Setting up and\nconfiguring procmail')"] + S5B_2["mkdir /srv/mail (0o1777)\nrender procmailrc.tmpl\nwrite → /etc/procmailrc"] + S5B_3{"postconf -e\nmailbox_command=procmail…"} + S5B_RAISE["raises ConfigurationError\n'Failed to configure\npostfix: '"] + S5B_OK["service_reload('postfix',\nrestart_on_failure=True)"] + + CATCH2["except ConfigurationError as e\nunit.status = Blocked(str(e))\nreturn"] + + S5C["_open_ports()\ntcp: 143, 993, 110, 995, 4190, 9900"] + ACTIVE(["unit.status = Active()"]) + + %% ── Wiring ─────────────────────────────────────────────────────────────── + START --> S1 --> TRY1 --> S2 + S2 -->|"raises"| S2_RAISES + S2 -->|"ok"| S3 + + S3 --> S3A + S3A -->|"True"| S3A_MT + S3A_MT -->|"not mounted"| S3A_RAISE + S3A_MT -->|"mounted"| S3A_OK + + S3A -->|"False (manage_luks=True)"| S3B + S3B -->|"None"| S3B_NONE + S3B -->|"found"| S3C + S3C -->|"invalid"| S3C_BAD + S3C -->|"valid"| S3D + S3D --> S3D_STEPS + S3D_STEPS -->|"CalledProcessError"| S3D_CPE + S3D_STEPS -->|"RuntimeError"| S3D_RTE + S3D_STEPS -->|"success"| S3D_OK + + S3A_OK & S3B_NONE & S3C_BAD & S3D_OK --> S3E + S3E --> S3E_STEPS + + S2_RAISES & S3A_RAISE & S3D_CPE & S3D_RTE --> CATCH1 + + S3E_STEPS --> S4 + S4 -->|"None"| S4_NONE + S4 -->|"found"| TRY2 + + TRY2 --> S5A --> S5A_1 --> S5A_2 --> S5A_3 + S5A_3 -->|"non-zero exit"| S5A_RAISE + S5A_3 -->|"exit 0"| S5A_OK --> S5B + S5B --> S5B_1 --> S5B_2 --> S5B_3 + S5B_3 -->|"CalledProcessError"| S5B_RAISE + S5B_3 -->|"success"| S5B_OK + + S5A_RAISE & S5B_RAISE --> CATCH2 + + S5B_OK --> S5C --> ACTIVE + + %% ── Styles ─────────────────────────────────────────────────────────────── + classDef tryblock fill:#ede9fe,stroke:#7c3aed,color:#3b0764 + classDef catch fill:#fee2e2,stroke:#dc2626,color:#7f1d1d + classDef decision fill:#e0f2fe,stroke:#0284c7,color:#0c4a6e + classDef action fill:#f3f4f6,stroke:#6b7280,color:#111827 + classDef maint fill:#fef9c3,stroke:#ca8a04,color:#713f12 + classDef blocked fill:#fee2e2,stroke:#dc2626,color:#7f1d1d + classDef active fill:#dcfce7,stroke:#16a34a,color:#14532d + classDef silent fill:#f3f4f6,stroke:#9ca3af,color:#6b7280,stroke-dasharray:4 4 + classDef start fill:#dbeafe,stroke:#3b82f6,color:#1e3a5f + classDef raises fill:#fef3c7,stroke:#d97706,color:#78350f + + class START start + class TRY1,TRY2 tryblock + class CATCH1,CATCH2 catch + class S3A,S3A_MT,S3B,S3C,S5A_3,S5B_3,S4 decision + class S2,S3,S3D,S3D_STEPS,S3E,S3E_STEPS,S5A,S5A_1,S5A_2,S5A_OK,S5B,S5B_1,S5B_2,S5B_OK,S5C action + class S5A_1,S5B_1 maint + class S2_RAISES,S3A_RAISE,S3D_CPE,S3D_RTE,S5A_RAISE,S5B_RAISE raises + class ACTIVE active + class S3B_NONE,S3C_BAD,S4_NONE silent +``` + +--- + +## Notes + +- **`_on_install`** no longer guards on config — just installs packages then calls `_reconcile`. Config blocking handled entirely inside `_reconcile`. +- **`_configure` deleted** — inlined into `_reconcile` as second `try/except` block. +- **Status written only in `_reconcile`** catch blocks (and transient Maintenance in individual setup methods). No function outside `_reconcile`/`_on_install` writes Blocked directly. +- **Exception hierarchy:** `StorageError` and `ConfigurationError` both extend `CharmBlockedError`. First `try/except` catches `CharmBlockedError` (both types). Second catches `ConfigurationError` only. +- **`teardown_detaching_storage`** never raises — `CalledProcessError` during umount/luksClose is logged and swallowed. Not in either try block. +- **Silent hang** remains: if `doveconf` absent, unit stays in `Maintenance("Configuring charm")` until next event. +- **No `WaitingStatus`** used anywhere. +- **LUKS key** fetched from Juju secret at config-validation time; passed to `cryptsetup` via stdin. diff --git a/dovecot-charm/tests/integration/conftest.py b/dovecot-charm/tests/integration/conftest.py index cda7b11..52861a5 100644 --- a/dovecot-charm/tests/integration/conftest.py +++ b/dovecot-charm/tests/integration/conftest.py @@ -3,6 +3,7 @@ import logging import typing +from secrets import token_hex import jubilant import pytest @@ -49,11 +50,12 @@ def dovecot_charm( ) -> str: """Build and deploy the charm.""" logging.info(f"Checking for existing application {APP_NAME}...") + luks_key = token_hex(16) if not juju.status().apps.get(APP_NAME): logging.info(f"Application {APP_NAME} not found, proceeding with deployment.") - secret_id = juju.cli("add-secret", "dovecot-luks-key", "key=s3cr3tpassphrase").strip() + secret_id = juju.cli("add-secret", "dovecot-luks-key", f"key={luks_key}").strip() logging.info(f"Created LUKS secret: {secret_id}") config = { From 839d37eedf12303319660bd1e543b56b8ba88144 Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Mon, 20 Apr 2026 07:31:09 +0300 Subject: [PATCH 09/39] fix(test): mock ensure_storage_ready in TLS tests to prevent PermissionError --- dovecot-charm/tests/unit/test_charm.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/dovecot-charm/tests/unit/test_charm.py b/dovecot-charm/tests/unit/test_charm.py index 1871251..6e80d1a 100644 --- a/dovecot-charm/tests/unit/test_charm.py +++ b/dovecot-charm/tests/unit/test_charm.py @@ -136,6 +136,7 @@ def test_certificate_available_writes_files(ctx, base_state, tmp_path): patch("charm.DovecotCharm._install"), patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), + patch("charm.ensure_storage_ready"), patch("charm.systemd.service_reload", return_value=True), ctx(ctx.on.config_changed(), base_state) as mgr, ): @@ -155,6 +156,7 @@ def test_certificate_available_no_mailname_returns(ctx, base_state): patch("charm.DovecotCharm._install"), patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), + patch("charm.ensure_storage_ready"), patch("charm.systemd.service_reload") as mock_service_reload, ctx(ctx.on.config_changed(), state_in) as mgr, ): @@ -168,6 +170,7 @@ def test_certificate_available_restarts_dovecot(ctx, base_state, tmp_path): patch("charm.DovecotCharm._install"), patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), + patch("charm.ensure_storage_ready"), patch("charm.systemd.service_reload", return_value=True) as mock_service_reload, ctx(ctx.on.config_changed(), base_state) as mgr, ): From d255443b6227cbfb85523bf463cfaaaa6d0e7e1e Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Mon, 20 Apr 2026 08:22:13 +0300 Subject: [PATCH 10/39] refactor(tls): make TLS mandatory via _setup_tls in _reconcile - Replace _on_certificate_available with _setup_tls called from _reconcile - Wire certificate_available event to _reconcile (not separate handler) - Always ssl=required in dovecot.conf (no conditional) - Add TLS_CERT_DIR constant, use get_assigned_certificate() API - Charm blocks with distinct messages for missing relation vs missing cert - New tests/unit/test_tls.py with 6 tests following SKILL.md principles - Fix integration tests: copyright, sleeps, stat quoting, ssl assertion - Update state diagrams for TLS states and events --- .opencode/plans/pr3-tls-review-fixes.md | 19 +++ docs/explanation/charm-state-diagrams.md | 44 ++++-- dovecot-charm/src/charm.py | 65 ++++---- dovecot-charm/src/constants.py | 2 + dovecot-charm/templates/dovecot.conf.tmpl | 5 - dovecot-charm/tests/integration/test_tls.py | 16 +- dovecot-charm/tests/unit/test_charm.py | 86 ++--------- dovecot-charm/tests/unit/test_storage.py | 12 ++ dovecot-charm/tests/unit/test_tls.py | 160 ++++++++++++++++++++ 9 files changed, 283 insertions(+), 126 deletions(-) create mode 100644 .opencode/plans/pr3-tls-review-fixes.md create mode 100644 dovecot-charm/tests/unit/test_tls.py diff --git a/.opencode/plans/pr3-tls-review-fixes.md b/.opencode/plans/pr3-tls-review-fixes.md new file mode 100644 index 0000000..e086736 --- /dev/null +++ b/.opencode/plans/pr3-tls-review-fixes.md @@ -0,0 +1,19 @@ +# Plan: pr/3-tls Review Fixes + +**Branch:** `pr/3-tls` (rebased on `origin/main` after pr/2 merge) +**Problem:** The 9 commits have the original (pre-review) TLS implementation. All review fixes were lost during rebase (never committed). +**Goal:** Re-apply all fixes, commit, push. + +## Steps + +1. `constants.py` — add `TLS_CERT_DIR = Path("/etc/dovecot/private")` +2. `charm.py` — delete `_on_certificate_available`, create `_setup_tls`, wire `certificate_available` → `_reconcile`, use `TLS_CERT_DIR` constant, remove `tls_enabled` from template context +3. `dovecot.conf.tmpl` — always `ssl = required`, remove conditional and stale TODO +4. `test_charm.py` — delete 3 old TLS tests, replace `_install` patches with `_setup_tls`, add comments +5. `test_storage.py` — add `_setup_tls` patches to 6 tests reaching ActiveStatus +6. `tests/unit/test_tls.py` (new) — 6 tests following SKILL.md principles +7. `tests/integration/test_tls.py` — fix copyright, remove sleeps, fix stat quoting, add ssl=required assertion +8. `docs/explanation/charm-state-diagrams.md` — update for TLS states +9. Remove `dovecot-charm-state-diagrams.rst` if exists +10. Run `tox -e fmt,unit,lint` — all must pass +11. Commit and push diff --git a/docs/explanation/charm-state-diagrams.md b/docs/explanation/charm-state-diagrams.md index 400437e..b914c83 100644 --- a/docs/explanation/charm-state-diagrams.md +++ b/docs/explanation/charm-state-diagrams.md @@ -1,6 +1,6 @@ # Dovecot Charm State Diagrams -Based on `origin/pr/2-storage` + exception-based refactor. +Based on `pr/3-tls`: storage + TLS + exception-based reconcile. --- @@ -17,6 +17,7 @@ flowchart TD EV_UPGRADE([upgrade_charm]) EV_STOR_ATT([mail_data_storage_attached]) EV_STOR_DET([mail_data_storage_detaching]) + EV_CERT_AVAIL([certificate_available]) EV_PEER([replicas.relation_created]) EV_ACTION([clear_queue action]) @@ -39,6 +40,8 @@ flowchart TD B_LUKS_DISABLED(["✖ Blocked\nmail-data not mounted;\nmanage-luks disabled\nraised: StorageError"]) B_LUKS_FAILED(["✖ Blocked\nFailed to setup LUKS storage\nraised: StorageError"]) B_LUKS_RT(["✖ Blocked\n\n(device missing / not block /\nluksFormat / open /\ndmsetup / mkfs / mount)\nraised: StorageError"]) + B_TLS_NO_REL(["✖ Blocked\nTLS certificates relation not available.\nIntegrate with a TLS provider.\nraised: ConfigurationError"]) + B_TLS_NO_CERT(["✖ Blocked\nTLS certificate not yet available\nfrom the certificates relation.\nraised: ConfigurationError"]) B_DOVECONF(["✖ Blocked\nInvalid Dovecot configuration\nraised: ConfigurationError"]) B_POSTFIX(["✖ Blocked\nFailed to configure postfix:\n\nraised: ConfigurationError"]) @@ -52,6 +55,7 @@ flowchart TD EV_UPGRADE --> H_RECONCILE EV_STOR_ATT --> H_RECONCILE EV_STOR_DET --> H_RECONCILE + EV_CERT_AVAIL --> H_RECONCILE EV_PEER --> H_PEER EV_ACTION --> H_ACTION @@ -68,9 +72,11 @@ flowchart TD M_CONFIGURING -->|"StorageError: CalledProcessError\n(ensure_storage_ready)"| B_LUKS_FAILED M_CONFIGURING -->|"StorageError: RuntimeError\n(ensure_storage_ready)"| B_LUKS_RT M_CONFIGURING -->|"shutil.which('doveconf') is None"| SILENT - M_CONFIGURING -->|"all pass → _setup_dovecot"| M_DOVECOT + M_CONFIGURING -->|"all pass → _setup_tls"| B_TLS_NO_REL + M_CONFIGURING -->|"all pass → _setup_tls"| B_TLS_NO_CERT + M_CONFIGURING -->|"tls cert written → _setup_dovecot"| M_DOVECOT - %% ── _reconcile: dovecot+procmail try/except block ──────────────────────── + %% ── _reconcile: tls+dovecot+procmail try/except block ──────────────────── M_DOVECOT -->|"ConfigurationError\n(doveconf -c fails)"| B_DOVECONF M_DOVECOT -->|"validation OK\n→ service_reload(dovecot)"| M_DOVECOT_OK M_DOVECOT_OK --> M_PROCMAIL @@ -85,10 +91,10 @@ flowchart TD classDef active fill:#dcfce7,stroke:#16a34a,color:#14532d classDef silent fill:#f3f4f6,stroke:#9ca3af,color:#6b7280,stroke-dasharray:4 4 - class EV_INSTALL,EV_CONFIG,EV_UPGRADE,EV_STOR_ATT,EV_STOR_DET,EV_PEER,EV_ACTION event + class EV_INSTALL,EV_CONFIG,EV_UPGRADE,EV_STOR_ATT,EV_STOR_DET,EV_CERT_AVAIL,EV_PEER,EV_ACTION event class H_INSTALL,H_RECONCILE,H_PEER,H_ACTION handler class M_INSTALLING,M_DEPS,M_DONE,M_CONFIGURING,M_DOVECOT,M_DOVECOT_OK,M_PROCMAIL maint - class B_CONFIG,B_LUKS_DISABLED,B_LUKS_FAILED,B_LUKS_RT,B_DOVECONF,B_POSTFIX blocked + class B_CONFIG,B_LUKS_DISABLED,B_LUKS_FAILED,B_LUKS_RT,B_TLS_NO_REL,B_TLS_NO_CERT,B_DOVECONF,B_POSTFIX blocked class ACTIVE active class SILENT silent ``` @@ -101,7 +107,7 @@ Full execution path inside `_reconcile`, showing both `try/except` blocks and ev ```mermaid flowchart TD - START(["_reconcile(event) called\nconfig_changed / upgrade_charm /\nmail_data_storage_attached /\nmail_data_storage_detaching /\n[via _on_install]"]) + START(["_reconcile(event) called\nconfig_changed / upgrade_charm /\nmail_data_storage_attached /\nmail_data_storage_detaching /\ncertificate_available /\n[via _on_install]"]) S1["unit.status =\nMaintenance('Configuring charm')"] @@ -139,12 +145,17 @@ flowchart TD S4{"shutil.which('doveconf')"} S4_NONE["log warning\n'Dovecot not installed yet'\nreturn\n(stays in Maintenance\n'Configuring charm')"] - %% ── try block 2: dovecot + procmail ───────────────────────────────────── + %% ── try block 2: tls + dovecot + procmail ─────────────────────────────── TRY2[/"try"/] + S5TLS["_setup_tls(dovecot_config)"] + S5TLS_NO_REL["raises ConfigurationError\n'TLS certificates relation\nnot available…'"] + S5TLS_NO_CERT["raises ConfigurationError\n'TLS certificate not yet\navailable…'"] + S5TLS_OK["write cert → /etc/dovecot/private/.pem (0o644)\nwrite key → /etc/dovecot/private/.key (0o600)"] + S5A["_setup_dovecot(dovecot_config)"] S5A_1["unit.status =\nMaintenance('Setting up and\nconfiguring dovecot')"] - S5A_2["render dovecot.conf.tmpl\nwrite → /etc/dovecot/conf.d/\n99-local-dovecot-charm.conf"] + S5A_2["render dovecot.conf.tmpl\n(ssl=required, mailname cert paths)\nwrite → /etc/dovecot/conf.d/\n99-local-dovecot-charm.conf"] S5A_3{"doveconf -c\n/etc/dovecot/conf.d/\n99-local-dovecot-charm.conf"} S5A_RAISE["raises ConfigurationError\n'Invalid Dovecot configuration,\ncheck logs for details'"] S5A_OK["service_reload('dovecot',\nrestart_on_failure=True)\nunit.status =\nMaintenance('Dovecot\nconfiguration updated')"] @@ -190,14 +201,18 @@ flowchart TD S4 -->|"None"| S4_NONE S4 -->|"found"| TRY2 - TRY2 --> S5A --> S5A_1 --> S5A_2 --> S5A_3 + TRY2 --> S5TLS + S5TLS -->|"_tls is None"| S5TLS_NO_REL + S5TLS -->|"get_assigned_certificate\nreturns (None,None)"| S5TLS_NO_CERT + S5TLS -->|"cert+key obtained"| S5TLS_OK --> S5A + S5A --> S5A_1 --> S5A_2 --> S5A_3 S5A_3 -->|"non-zero exit"| S5A_RAISE S5A_3 -->|"exit 0"| S5A_OK --> S5B S5B --> S5B_1 --> S5B_2 --> S5B_3 S5B_3 -->|"CalledProcessError"| S5B_RAISE S5B_3 -->|"success"| S5B_OK - S5A_RAISE & S5B_RAISE --> CATCH2 + S5TLS_NO_REL & S5TLS_NO_CERT & S5A_RAISE & S5B_RAISE --> CATCH2 S5B_OK --> S5C --> ACTIVE @@ -216,10 +231,10 @@ flowchart TD class START start class TRY1,TRY2 tryblock class CATCH1,CATCH2 catch - class S3A,S3A_MT,S3B,S3C,S5A_3,S5B_3,S4 decision - class S2,S3,S3D,S3D_STEPS,S3E,S3E_STEPS,S5A,S5A_1,S5A_2,S5A_OK,S5B,S5B_1,S5B_2,S5B_OK,S5C action + class S3A,S3A_MT,S3B,S3C,S5A_3,S5B_3,S4,S5TLS decision + class S2,S3,S3D,S3D_STEPS,S3E,S3E_STEPS,S5TLS_OK,S5A,S5A_1,S5A_2,S5A_OK,S5B,S5B_1,S5B_2,S5B_OK,S5C action class S5A_1,S5B_1 maint - class S2_RAISES,S3A_RAISE,S3D_CPE,S3D_RTE,S5A_RAISE,S5B_RAISE raises + class S2_RAISES,S3A_RAISE,S3D_CPE,S3D_RTE,S5TLS_NO_REL,S5TLS_NO_CERT,S5A_RAISE,S5B_RAISE raises class ACTIVE active class S3B_NONE,S3C_BAD,S4_NONE silent ``` @@ -230,6 +245,9 @@ flowchart TD - **`_on_install`** no longer guards on config — just installs packages then calls `_reconcile`. Config blocking handled entirely inside `_reconcile`. - **`_configure` deleted** — inlined into `_reconcile` as second `try/except` block. +- **`certificate_available` wired to `_reconcile`** — same handler as all other events. No separate `_on_certificate_available`. +- **TLS is mandatory**: `ssl = required` always in dovecot.conf. The charm will not reach `ActiveStatus` without a working `certificates` relation that has issued a cert. +- **`_setup_tls`** runs first in the second try block — writes cert+key from relation data to `/etc/dovecot/private/` before dovecot config is rendered or validated. - **Status written only in `_reconcile`** catch blocks (and transient Maintenance in individual setup methods). No function outside `_reconcile`/`_on_install` writes Blocked directly. - **Exception hierarchy:** `StorageError` and `ConfigurationError` both extend `CharmBlockedError`. First `try/except` catches `CharmBlockedError` (both types). Second catches `ConfigurationError` only. - **`teardown_detaching_storage`** never raises — `CalledProcessError` during umount/luksClose is logged and swallowed. Not in either try block. diff --git a/dovecot-charm/src/charm.py b/dovecot-charm/src/charm.py index 6fd84b6..0b5fddf 100644 --- a/dovecot-charm/src/charm.py +++ b/dovecot-charm/src/charm.py @@ -15,7 +15,6 @@ from charmhelpers.core import host from charmlibs import apt, systemd from charmlibs.interfaces.tls_certificates import ( - CertificateAvailableEvent, CertificateRequestAttributes, TLSCertificatesRequiresV4, ) @@ -35,6 +34,7 @@ PROCMAILRC_TEMPLATE, REQUIRED_PACKAGES, TEMPLATES_DIR, + TLS_CERT_DIR, ) from dovecot_config import DovecotConfig, DovecotConfigInvalidError, DovecotConfigSecretError from exceptions import CharmBlockedError, ConfigurationError @@ -67,9 +67,6 @@ def __init__(self, *args): loader=jinja2.FileSystemLoader(TEMPLATES_DIR), autoescape=True ) - # TLS certificates directory - self.tls_cert_dir = Path("/etc/dovecot/private") - # TLS certificates integration self._tls = None mailname = self.config.get("mailname", "") @@ -85,9 +82,7 @@ def __init__(self, *args): ], refresh_events=[self.on.config_changed], ) - self.framework.observe( - self._tls.on.certificate_available, self._on_certificate_available - ) + self.framework.observe(self._tls.on.certificate_available, self._reconcile) def get_units(self) -> typing.List[str]: """Return a list of all units in the application. @@ -153,6 +148,7 @@ def _reconcile(self, event): logger.warning("Dovecot not installed yet, deferring configuration") return try: + self._setup_tls(dovecot_config) self._setup_dovecot(dovecot_config) self._setup_procmail() except ConfigurationError as e: @@ -190,7 +186,6 @@ def _setup_dovecot(self, dovecot_config: DovecotConfig) -> None: "mail_root": MAIL_ROOT, "mailname": dovecot_config.mailname, "postmaster_address": dovecot_config.postmaster_address, - "tls_enabled": (self.tls_cert_dir / f"{dovecot_config.mailname}.pem").exists(), } template = self.jinja.get_template(DOVECOT_CONF_TEMPLATE) contents = template.render(template_context) @@ -276,34 +271,46 @@ def _on_clear_queue_action(self, event): logger.exception(f"Failed to clear Postfix queue: {e.stderr}") event.fail(f"Failed to run postsuper: {e.stderr}") - def _on_certificate_available(self, event: CertificateAvailableEvent): - """Handle TLS certificate available event.""" - mailname = self.config.get("mailname", "") - if not mailname: - logger.warning("Certificate available but mailname not configured") - return + def _setup_tls(self, dovecot_config: DovecotConfig) -> None: + """Write TLS cert+key to disk from the certificates relation. - self.tls_cert_dir.mkdir(parents=True, exist_ok=True) + Called from _reconcile before _setup_dovecot so the cert files are + present when dovecot.conf is rendered and validated. - cert_path = self.tls_cert_dir / f"{mailname}.pem" - key_path = self.tls_cert_dir / f"{mailname}.key" + Raises: + ConfigurationError: If no TLS relation exists or the certificate + has not been issued yet. + """ + if not self._tls: + raise ConfigurationError( + "TLS certificates relation not available. " + "Integrate with a TLS provider using the 'certificates' relation." + ) + + cert_request = CertificateRequestAttributes( + common_name=dovecot_config.mailname, + sans_dns=frozenset([dovecot_config.mailname]), + ) + provider_cert, private_key = self._tls.get_assigned_certificate(cert_request) + if not provider_cert or not private_key: + raise ConfigurationError( + "TLS certificate not yet available from the certificates relation." + ) - cert_content = str(event.certificate.certificate) - if event.certificate.ca: - cert_content += "\n" + str(event.certificate.ca) + TLS_CERT_DIR.mkdir(parents=True, exist_ok=True) + cert_path = TLS_CERT_DIR / f"{dovecot_config.mailname}.pem" + key_path = TLS_CERT_DIR / f"{dovecot_config.mailname}.key" + cert_content = str(provider_cert.certificate) + if provider_cert.ca: + cert_content += "\n" + str(provider_cert.ca) cert_path.write_text(cert_content) cert_path.chmod(0o644) - logger.info(f"Certificate written to {cert_path}") - - private_key = self._tls.private_key - if private_key: - key_path.write_text(str(private_key)) - key_path.chmod(0o600) - logger.info(f"Private key written to {key_path}") + logger.info(f"TLS certificate written to {cert_path}") - if systemd.service_reload("dovecot"): - logger.info("Dovecot service reloaded with new TLS certificate") + key_path.write_text(str(private_key)) + key_path.chmod(0o600) + logger.info(f"TLS private key written to {key_path}") if __name__ == "__main__": # pragma: nocover diff --git a/dovecot-charm/src/constants.py b/dovecot-charm/src/constants.py index a02104b..4df5950 100644 --- a/dovecot-charm/src/constants.py +++ b/dovecot-charm/src/constants.py @@ -45,3 +45,5 @@ # start hook can re-open LUKS without relying on `storage-get` (which fails # when Juju has not yet re-provisioned the storage after a VM restart). STORAGE_DEV_PATH_FILE = "/var/lib/dovecot-charm/storage-dev-path" + +TLS_CERT_DIR = Path("/etc/dovecot/private") diff --git a/dovecot-charm/templates/dovecot.conf.tmpl b/dovecot-charm/templates/dovecot.conf.tmpl index 89a4f89..247e2ef 100644 --- a/dovecot-charm/templates/dovecot.conf.tmpl +++ b/dovecot-charm/templates/dovecot.conf.tmpl @@ -9,14 +9,9 @@ default_vsz_limit = 256M auth_verbose = yes auth_verbose_passwords = no -# TODO: change to ssl = required once TLS relation is added (pr/5-tls) -{% if tls_enabled %} ssl = required ssl_cert = Date: Mon, 20 Apr 2026 08:31:52 +0300 Subject: [PATCH 11/39] delete(docs): remove Dovecot charm state diagrams documentation --- docs/explanation/charm-state-diagrams.md | 256 ----------------------- 1 file changed, 256 deletions(-) delete mode 100644 docs/explanation/charm-state-diagrams.md diff --git a/docs/explanation/charm-state-diagrams.md b/docs/explanation/charm-state-diagrams.md deleted file mode 100644 index b914c83..0000000 --- a/docs/explanation/charm-state-diagrams.md +++ /dev/null @@ -1,256 +0,0 @@ -# Dovecot Charm State Diagrams - -Based on `pr/3-tls`: storage + TLS + exception-based reconcile. - ---- - -## Diagram 1 — Event → Handler → Unit Status - -Shows which Juju events trigger which handlers, and every possible `unit.status` outcome. -Actions and `replicas.relation_created` produce no status changes. - -```mermaid -flowchart TD - %% ── Juju events ────────────────────────────────────────────────────────── - EV_INSTALL([install]) - EV_CONFIG([config_changed]) - EV_UPGRADE([upgrade_charm]) - EV_STOR_ATT([mail_data_storage_attached]) - EV_STOR_DET([mail_data_storage_detaching]) - EV_CERT_AVAIL([certificate_available]) - EV_PEER([replicas.relation_created]) - EV_ACTION([clear_queue action]) - - %% ── Handlers ───────────────────────────────────────────────────────────── - H_INSTALL[_on_install] - H_RECONCILE[_reconcile] - H_PEER[_on_peer_relation_created\nwrites unit-name to relation data] - H_ACTION[_on_clear_queue_action\nno status change] - - %% ── Status outcomes ────────────────────────────────────────────────────── - M_INSTALLING(["● Maintenance\nInstalling packages"]) - M_DEPS(["● Maintenance\nInstalling required dependencies"]) - M_DONE(["● Maintenance\nCharm installation done"]) - M_CONFIGURING(["● Maintenance\nConfiguring charm"]) - M_DOVECOT(["● Maintenance\nSetting up and configuring dovecot"]) - M_DOVECOT_OK(["● Maintenance\nDovecot configuration updated"]) - M_PROCMAIL(["● Maintenance\nSetting up and configuring procmail"]) - - B_CONFIG(["✖ Blocked\nInvalid charm configuration\n(mailname / postmaster-address /\nprimary-unit / luks-key)\nraised: ConfigurationError"]) - B_LUKS_DISABLED(["✖ Blocked\nmail-data not mounted;\nmanage-luks disabled\nraised: StorageError"]) - B_LUKS_FAILED(["✖ Blocked\nFailed to setup LUKS storage\nraised: StorageError"]) - B_LUKS_RT(["✖ Blocked\n\n(device missing / not block /\nluksFormat / open /\ndmsetup / mkfs / mount)\nraised: StorageError"]) - B_TLS_NO_REL(["✖ Blocked\nTLS certificates relation not available.\nIntegrate with a TLS provider.\nraised: ConfigurationError"]) - B_TLS_NO_CERT(["✖ Blocked\nTLS certificate not yet available\nfrom the certificates relation.\nraised: ConfigurationError"]) - B_DOVECONF(["✖ Blocked\nInvalid Dovecot configuration\nraised: ConfigurationError"]) - B_POSTFIX(["✖ Blocked\nFailed to configure postfix:\n\nraised: ConfigurationError"]) - - ACTIVE(["✔ Active"]) - - SILENT["(no status change)\ndoveconf not yet installed\n— logs warning, returns"] - - %% ── Event wiring ───────────────────────────────────────────────────────── - EV_INSTALL --> H_INSTALL - EV_CONFIG --> H_RECONCILE - EV_UPGRADE --> H_RECONCILE - EV_STOR_ATT --> H_RECONCILE - EV_STOR_DET --> H_RECONCILE - EV_CERT_AVAIL --> H_RECONCILE - EV_PEER --> H_PEER - EV_ACTION --> H_ACTION - - %% ── _on_install flow ───────────────────────────────────────────────────── - H_INSTALL --> M_INSTALLING - M_INSTALLING --> M_DEPS - M_DEPS --> M_DONE - M_DONE -->|"calls _reconcile"| H_RECONCILE - - %% ── _reconcile: storage+config try/except block ────────────────────────── - H_RECONCILE --> M_CONFIGURING - M_CONFIGURING -->|"ConfigurationError\n(_get_dovecot_config)"| B_CONFIG - M_CONFIGURING -->|"StorageError: not mounted\n(ensure_storage_ready)"| B_LUKS_DISABLED - M_CONFIGURING -->|"StorageError: CalledProcessError\n(ensure_storage_ready)"| B_LUKS_FAILED - M_CONFIGURING -->|"StorageError: RuntimeError\n(ensure_storage_ready)"| B_LUKS_RT - M_CONFIGURING -->|"shutil.which('doveconf') is None"| SILENT - M_CONFIGURING -->|"all pass → _setup_tls"| B_TLS_NO_REL - M_CONFIGURING -->|"all pass → _setup_tls"| B_TLS_NO_CERT - M_CONFIGURING -->|"tls cert written → _setup_dovecot"| M_DOVECOT - - %% ── _reconcile: tls+dovecot+procmail try/except block ──────────────────── - M_DOVECOT -->|"ConfigurationError\n(doveconf -c fails)"| B_DOVECONF - M_DOVECOT -->|"validation OK\n→ service_reload(dovecot)"| M_DOVECOT_OK - M_DOVECOT_OK --> M_PROCMAIL - M_PROCMAIL -->|"ConfigurationError\n(postconf -e fails)"| B_POSTFIX - M_PROCMAIL -->|"service_reload(postfix) OK\n→ open_ports()"| ACTIVE - - %% ── Styles ─────────────────────────────────────────────────────────────── - classDef event fill:#dbeafe,stroke:#3b82f6,color:#1e3a5f - classDef handler fill:#f3f4f6,stroke:#6b7280,color:#111827 - classDef maint fill:#fef9c3,stroke:#ca8a04,color:#713f12 - classDef blocked fill:#fee2e2,stroke:#dc2626,color:#7f1d1d - classDef active fill:#dcfce7,stroke:#16a34a,color:#14532d - classDef silent fill:#f3f4f6,stroke:#9ca3af,color:#6b7280,stroke-dasharray:4 4 - - class EV_INSTALL,EV_CONFIG,EV_UPGRADE,EV_STOR_ATT,EV_STOR_DET,EV_CERT_AVAIL,EV_PEER,EV_ACTION event - class H_INSTALL,H_RECONCILE,H_PEER,H_ACTION handler - class M_INSTALLING,M_DEPS,M_DONE,M_CONFIGURING,M_DOVECOT,M_DOVECOT_OK,M_PROCMAIL maint - class B_CONFIG,B_LUKS_DISABLED,B_LUKS_FAILED,B_LUKS_RT,B_TLS_NO_REL,B_TLS_NO_CERT,B_DOVECONF,B_POSTFIX blocked - class ACTIVE active - class SILENT silent -``` - ---- - -## Diagram 2 — `_reconcile` Internal Call Chain - -Full execution path inside `_reconcile`, showing both `try/except` blocks and every branch. - -```mermaid -flowchart TD - START(["_reconcile(event) called\nconfig_changed / upgrade_charm /\nmail_data_storage_attached /\nmail_data_storage_detaching /\ncertificate_available /\n[via _on_install]"]) - - S1["unit.status =\nMaintenance('Configuring charm')"] - - %% ── try block 1: config + storage ─────────────────────────────────────── - TRY1[/"try"/] - - S2["_get_dovecot_config()\nDovecotConfig.from_charm()"] - S2_RAISES["raises ConfigurationError\n'Invalid charm configuration…'\n(mailname / postmaster-address /\nprimary-unit / luks-key)"] - - S3["ensure_storage_ready(charm)\nstorage.py"] - - S3A{"manage_luks = False"} - S3A_MT{"_mail_storage_mounted()\nos.path.ismount('/srv/mail')"} - S3A_RAISE["raises StorageError\n'mail-data not mounted;\nmanage-luks disabled'"] - S3A_OK["return (proceed)"] - - S3B{"manage_luks = True\nshutil.which('cryptsetup')"} - S3B_NONE["None → log warning\nreturn (defer silently)"] - - S3C{"storages / dev_path\nvalid?"} - S3C_BAD["empty or None\nlog error\nreturn (no block)"] - - S3D["setup_luks_storage(luks_key, dev_path)"] - S3D_STEPS["① isLuks check\n② luksFormat if new (key via stdin)\n③ cryptsetup open if not mapped\n④ dmsetup mknodes\n⑤ blkid check for ext4\n⑥ mkfs.ext4 if no fs\n⑦ configure_file /etc/fstab\n⑧ mount → /srv/mail"] - S3D_CPE["CalledProcessError\nraises StorageError\n'Failed to setup LUKS storage'"] - S3D_RTE["RuntimeError\nraises StorageError(str(e))"] - S3D_OK["return (LUKS ready)"] - - S3E["teardown_detaching_storage(charm)"] - S3E_STEPS["if storages present → return (no-op)\nif storages gone:\n manage_luks + mounted → umount\n mapper exists → luksClose\nCalledProcessError → log only"] - - CATCH1["except CharmBlockedError as e\nunit.status = Blocked(str(e))\nreturn"] - - %% ── doveconf guard ─────────────────────────────────────────────────────── - S4{"shutil.which('doveconf')"} - S4_NONE["log warning\n'Dovecot not installed yet'\nreturn\n(stays in Maintenance\n'Configuring charm')"] - - %% ── try block 2: tls + dovecot + procmail ─────────────────────────────── - TRY2[/"try"/] - - S5TLS["_setup_tls(dovecot_config)"] - S5TLS_NO_REL["raises ConfigurationError\n'TLS certificates relation\nnot available…'"] - S5TLS_NO_CERT["raises ConfigurationError\n'TLS certificate not yet\navailable…'"] - S5TLS_OK["write cert → /etc/dovecot/private/.pem (0o644)\nwrite key → /etc/dovecot/private/.key (0o600)"] - - S5A["_setup_dovecot(dovecot_config)"] - S5A_1["unit.status =\nMaintenance('Setting up and\nconfiguring dovecot')"] - S5A_2["render dovecot.conf.tmpl\n(ssl=required, mailname cert paths)\nwrite → /etc/dovecot/conf.d/\n99-local-dovecot-charm.conf"] - S5A_3{"doveconf -c\n/etc/dovecot/conf.d/\n99-local-dovecot-charm.conf"} - S5A_RAISE["raises ConfigurationError\n'Invalid Dovecot configuration,\ncheck logs for details'"] - S5A_OK["service_reload('dovecot',\nrestart_on_failure=True)\nunit.status =\nMaintenance('Dovecot\nconfiguration updated')"] - - S5B["_setup_procmail()"] - S5B_1["unit.status =\nMaintenance('Setting up and\nconfiguring procmail')"] - S5B_2["mkdir /srv/mail (0o1777)\nrender procmailrc.tmpl\nwrite → /etc/procmailrc"] - S5B_3{"postconf -e\nmailbox_command=procmail…"} - S5B_RAISE["raises ConfigurationError\n'Failed to configure\npostfix: '"] - S5B_OK["service_reload('postfix',\nrestart_on_failure=True)"] - - CATCH2["except ConfigurationError as e\nunit.status = Blocked(str(e))\nreturn"] - - S5C["_open_ports()\ntcp: 143, 993, 110, 995, 4190, 9900"] - ACTIVE(["unit.status = Active()"]) - - %% ── Wiring ─────────────────────────────────────────────────────────────── - START --> S1 --> TRY1 --> S2 - S2 -->|"raises"| S2_RAISES - S2 -->|"ok"| S3 - - S3 --> S3A - S3A -->|"True"| S3A_MT - S3A_MT -->|"not mounted"| S3A_RAISE - S3A_MT -->|"mounted"| S3A_OK - - S3A -->|"False (manage_luks=True)"| S3B - S3B -->|"None"| S3B_NONE - S3B -->|"found"| S3C - S3C -->|"invalid"| S3C_BAD - S3C -->|"valid"| S3D - S3D --> S3D_STEPS - S3D_STEPS -->|"CalledProcessError"| S3D_CPE - S3D_STEPS -->|"RuntimeError"| S3D_RTE - S3D_STEPS -->|"success"| S3D_OK - - S3A_OK & S3B_NONE & S3C_BAD & S3D_OK --> S3E - S3E --> S3E_STEPS - - S2_RAISES & S3A_RAISE & S3D_CPE & S3D_RTE --> CATCH1 - - S3E_STEPS --> S4 - S4 -->|"None"| S4_NONE - S4 -->|"found"| TRY2 - - TRY2 --> S5TLS - S5TLS -->|"_tls is None"| S5TLS_NO_REL - S5TLS -->|"get_assigned_certificate\nreturns (None,None)"| S5TLS_NO_CERT - S5TLS -->|"cert+key obtained"| S5TLS_OK --> S5A - S5A --> S5A_1 --> S5A_2 --> S5A_3 - S5A_3 -->|"non-zero exit"| S5A_RAISE - S5A_3 -->|"exit 0"| S5A_OK --> S5B - S5B --> S5B_1 --> S5B_2 --> S5B_3 - S5B_3 -->|"CalledProcessError"| S5B_RAISE - S5B_3 -->|"success"| S5B_OK - - S5TLS_NO_REL & S5TLS_NO_CERT & S5A_RAISE & S5B_RAISE --> CATCH2 - - S5B_OK --> S5C --> ACTIVE - - %% ── Styles ─────────────────────────────────────────────────────────────── - classDef tryblock fill:#ede9fe,stroke:#7c3aed,color:#3b0764 - classDef catch fill:#fee2e2,stroke:#dc2626,color:#7f1d1d - classDef decision fill:#e0f2fe,stroke:#0284c7,color:#0c4a6e - classDef action fill:#f3f4f6,stroke:#6b7280,color:#111827 - classDef maint fill:#fef9c3,stroke:#ca8a04,color:#713f12 - classDef blocked fill:#fee2e2,stroke:#dc2626,color:#7f1d1d - classDef active fill:#dcfce7,stroke:#16a34a,color:#14532d - classDef silent fill:#f3f4f6,stroke:#9ca3af,color:#6b7280,stroke-dasharray:4 4 - classDef start fill:#dbeafe,stroke:#3b82f6,color:#1e3a5f - classDef raises fill:#fef3c7,stroke:#d97706,color:#78350f - - class START start - class TRY1,TRY2 tryblock - class CATCH1,CATCH2 catch - class S3A,S3A_MT,S3B,S3C,S5A_3,S5B_3,S4,S5TLS decision - class S2,S3,S3D,S3D_STEPS,S3E,S3E_STEPS,S5TLS_OK,S5A,S5A_1,S5A_2,S5A_OK,S5B,S5B_1,S5B_2,S5B_OK,S5C action - class S5A_1,S5B_1 maint - class S2_RAISES,S3A_RAISE,S3D_CPE,S3D_RTE,S5TLS_NO_REL,S5TLS_NO_CERT,S5A_RAISE,S5B_RAISE raises - class ACTIVE active - class S3B_NONE,S3C_BAD,S4_NONE silent -``` - ---- - -## Notes - -- **`_on_install`** no longer guards on config — just installs packages then calls `_reconcile`. Config blocking handled entirely inside `_reconcile`. -- **`_configure` deleted** — inlined into `_reconcile` as second `try/except` block. -- **`certificate_available` wired to `_reconcile`** — same handler as all other events. No separate `_on_certificate_available`. -- **TLS is mandatory**: `ssl = required` always in dovecot.conf. The charm will not reach `ActiveStatus` without a working `certificates` relation that has issued a cert. -- **`_setup_tls`** runs first in the second try block — writes cert+key from relation data to `/etc/dovecot/private/` before dovecot config is rendered or validated. -- **Status written only in `_reconcile`** catch blocks (and transient Maintenance in individual setup methods). No function outside `_reconcile`/`_on_install` writes Blocked directly. -- **Exception hierarchy:** `StorageError` and `ConfigurationError` both extend `CharmBlockedError`. First `try/except` catches `CharmBlockedError` (both types). Second catches `ConfigurationError` only. -- **`teardown_detaching_storage`** never raises — `CalledProcessError` during umount/luksClose is logged and swallowed. Not in either try block. -- **Silent hang** remains: if `doveconf` absent, unit stays in `Maintenance("Configuring charm")` until next event. -- **No `WaitingStatus`** used anywhere. -- **LUKS key** fetched from Juju secret at config-validation time; passed to `cryptsetup` via stdin. From bf98cee40f3ad492824b1f7b34899fb32c9224e2 Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Mon, 20 Apr 2026 09:19:41 +0300 Subject: [PATCH 12/39] refactor(tls): streamline TLS integration in tests and remove redundant deployment logic --- dovecot-charm/pyproject.toml | 8 +----- dovecot-charm/tests/integration/conftest.py | 21 +++++++++++++++- dovecot-charm/tests/integration/test_tls.py | 28 +-------------------- dovecot-charm/tests/unit/test_tls.py | 1 - 4 files changed, 22 insertions(+), 36 deletions(-) diff --git a/dovecot-charm/pyproject.toml b/dovecot-charm/pyproject.toml index f427468..2848262 100644 --- a/dovecot-charm/pyproject.toml +++ b/dovecot-charm/pyproject.toml @@ -122,13 +122,7 @@ lint.per-file-ignores."tests/*" = [ "D417", "S", ] -lint.per-file-ignores."src/charms/*" = [ - "B006", - "B028", - "RUF100", - "S101", - "SIM103", -] + lint.flake8-copyright.author = "Canonical Ltd." lint.flake8-copyright.min-file-size = 1 lint.flake8-copyright.notice-rgx = "Copyright\\s\\d{4}([-,]\\d{4})*\\s+" diff --git a/dovecot-charm/tests/integration/conftest.py b/dovecot-charm/tests/integration/conftest.py index 52861a5..c056db0 100644 --- a/dovecot-charm/tests/integration/conftest.py +++ b/dovecot-charm/tests/integration/conftest.py @@ -47,6 +47,7 @@ def charm_fixture(pytestconfig: pytest.Config) -> str: def dovecot_charm( charm: str, juju: jubilant.Juju, + tls_charm: str, ) -> str: """Build and deploy the charm.""" logging.info(f"Checking for existing application {APP_NAME}...") @@ -74,10 +75,16 @@ def dovecot_charm( trust=True, ) + try: + logging.info("Adding TLS relation...") + juju.integrate(f"{dovecot_charm}:certificates", f"{tls_charm}:certificates") + except Exception: + logging.info("TLS relation already there...") + juju.cli("grant-secret", "dovecot-luks-key", APP_NAME) logging.info("Waiting for active status...") juju.wait( - lambda status: status.apps[APP_NAME].is_active, + lambda status: status.apps[APP_NAME].is_active and status.apps[tls_charm].is_active, timeout=10 * 60, ) return APP_NAME @@ -116,3 +123,15 @@ def dovecot_charm_manual_storage( timeout=10 * 60, ) return charm_name + + +@pytest.fixture(scope="module") +def tls_charm(juju: jubilant.Juju) -> str: + tls_app = "self-signed-certificates" + if tls_app not in juju.status().apps: + logging.info("Deploying self-signed-certificates...") + juju.deploy(tls_app, channel="latest/stable") + else: + logging.info(f"{tls_app} already deployed, skipping deployment.") + + return tls_app diff --git a/dovecot-charm/tests/integration/test_tls.py b/dovecot-charm/tests/integration/test_tls.py index b2e147d..aa223a4 100644 --- a/dovecot-charm/tests/integration/test_tls.py +++ b/dovecot-charm/tests/integration/test_tls.py @@ -5,31 +5,6 @@ import logging import ssl -import jubilant -import pytest - -TLS_APP = "self-signed-certificates" - - -@pytest.fixture(scope="module") -def deploy_with_tls(juju: jubilant.Juju, dovecot_charm: str): - if TLS_APP not in juju.status().apps: - logging.info("Deploying self-signed-certificates...") - juju.deploy(TLS_APP, channel="latest/stable") - else: - logging.info(f"{TLS_APP} already deployed, skipping deployment.") - - try: - logging.info("Adding TLS relation...") - juju.integrate(f"{dovecot_charm}:certificates", f"{TLS_APP}:certificates") - except Exception: - logging.info("TLS relation already there...") - - # The charm is Blocked without a certificate; wait until it becomes Active - # (meaning _setup_tls succeeded and cert files are written). - logging.info("Waiting for active/idle status...") - juju.wait(jubilant.all_active, timeout=1200) - def test_tls_certificate_files_written(juju, dovecot_charm, deploy_with_tls): """Verify that TLS certificate and key files are written to the unit.""" @@ -84,14 +59,13 @@ def test_tls_certificate_content_valid(juju, dovecot_charm, deploy_with_tls): def test_tls_dovecot_config_references_cert(juju, dovecot_charm, deploy_with_tls): - """Verify dovecot configuration uses ssl=required and references the cert.""" + """Verify dovecot configuration references the cert.""" unit_name = f"{dovecot_charm}/0" dovecot_conf = juju.exec( "cat", "/etc/dovecot/conf.d/99-local-dovecot-charm.conf", unit=unit_name ) logging.info("Checking dovecot SSL configuration...") - assert "ssl = required" in dovecot_conf.stdout assert "ssl_cert" in dovecot_conf.stdout assert "example.com" in dovecot_conf.stdout assert "ssl_min_protocol = TLSv1.2" in dovecot_conf.stdout diff --git a/dovecot-charm/tests/unit/test_tls.py b/dovecot-charm/tests/unit/test_tls.py index 7e0615c..f654cbc 100644 --- a/dovecot-charm/tests/unit/test_tls.py +++ b/dovecot-charm/tests/unit/test_tls.py @@ -6,7 +6,6 @@ from unittest.mock import MagicMock, patch import ops -import ops.testing import pytest From 6d24bfe6a415444f88454c024cf8fbb683a3d960 Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Mon, 20 Apr 2026 09:41:10 +0300 Subject: [PATCH 13/39] test(tls): add TLS tests and remove unused deploy_with_tls parameter --- .github/workflows/integration_test.yaml | 1 + dovecot-charm/tests/integration/conftest.py | 13 +++++++++---- dovecot-charm/tests/integration/test_tls.py | 10 +++++----- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/.github/workflows/integration_test.yaml b/.github/workflows/integration_test.yaml index 9b52c8b..92261d6 100644 --- a/.github/workflows/integration_test.yaml +++ b/.github/workflows/integration_test.yaml @@ -25,6 +25,7 @@ jobs: "test_config.py", "test_mail.py", "test_storage.py", + "test_tls.py", ] allure-report: if: ${{ !cancelled() && github.event_name == 'schedule' }} diff --git a/dovecot-charm/tests/integration/conftest.py b/dovecot-charm/tests/integration/conftest.py index c056db0..e1b3fc6 100644 --- a/dovecot-charm/tests/integration/conftest.py +++ b/dovecot-charm/tests/integration/conftest.py @@ -74,14 +74,12 @@ def dovecot_charm( constraints={"virt-type": "virtual-machine"}, trust=True, ) - + juju.cli("grant-secret", "dovecot-luks-key", APP_NAME) try: logging.info("Adding TLS relation...") - juju.integrate(f"{dovecot_charm}:certificates", f"{tls_charm}:certificates") + juju.integrate(f"{APP_NAME}:certificates", f"{tls_charm}:certificates") except Exception: logging.info("TLS relation already there...") - - juju.cli("grant-secret", "dovecot-luks-key", APP_NAME) logging.info("Waiting for active status...") juju.wait( lambda status: status.apps[APP_NAME].is_active and status.apps[tls_charm].is_active, @@ -94,6 +92,7 @@ def dovecot_charm( def dovecot_charm_manual_storage( charm: str, juju: jubilant.Juju, + tls_charm: str, ) -> str: """Build and deploy the charm.""" charm_name = f"{APP_NAME}-manual" @@ -117,6 +116,12 @@ def dovecot_charm_manual_storage( trust=True, ) + try: + logging.info("Adding TLS relation...") + juju.integrate(f"{dovecot_charm}:certificates", f"{tls_charm}:certificates") + except Exception: + logging.info("TLS relation already there...") + logging.info("Waiting for blocked status...") juju.wait( lambda status: status.apps[charm_name].is_blocked, diff --git a/dovecot-charm/tests/integration/test_tls.py b/dovecot-charm/tests/integration/test_tls.py index aa223a4..517ed57 100644 --- a/dovecot-charm/tests/integration/test_tls.py +++ b/dovecot-charm/tests/integration/test_tls.py @@ -6,7 +6,7 @@ import ssl -def test_tls_certificate_files_written(juju, dovecot_charm, deploy_with_tls): +def test_tls_certificate_files_written(juju, dovecot_charm): """Verify that TLS certificate and key files are written to the unit.""" unit_name = f"{dovecot_charm}/0" logging.info(f"Targeting unit: {unit_name}") @@ -22,7 +22,7 @@ def test_tls_certificate_files_written(juju, dovecot_charm, deploy_with_tls): assert "example.com.key" in key_check.stdout, "Key file not found" -def test_tls_certificate_permissions(juju, dovecot_charm, deploy_with_tls): +def test_tls_certificate_permissions(juju, dovecot_charm): """Verify correct file permissions on TLS cert and key.""" unit_name = f"{dovecot_charm}/0" @@ -43,7 +43,7 @@ def test_tls_certificate_permissions(juju, dovecot_charm, deploy_with_tls): ) -def test_tls_certificate_content_valid(juju, dovecot_charm, deploy_with_tls): +def test_tls_certificate_content_valid(juju, dovecot_charm): """Verify the certificate file contains a valid PEM certificate.""" unit_name = f"{dovecot_charm}/0" @@ -58,7 +58,7 @@ def test_tls_certificate_content_valid(juju, dovecot_charm, deploy_with_tls): ) -def test_tls_dovecot_config_references_cert(juju, dovecot_charm, deploy_with_tls): +def test_tls_dovecot_config_references_cert(juju, dovecot_charm): """Verify dovecot configuration references the cert.""" unit_name = f"{dovecot_charm}/0" @@ -71,7 +71,7 @@ def test_tls_dovecot_config_references_cert(juju, dovecot_charm, deploy_with_tls assert "ssl_min_protocol = TLSv1.2" in dovecot_conf.stdout -def test_tls_dovecot_ssl_port_responds(juju, dovecot_charm, deploy_with_tls): +def test_tls_dovecot_ssl_port_responds(juju, dovecot_charm): """Verify dovecot responds on the SSL IMAP port (993).""" unit_name = f"{dovecot_charm}/0" status = juju.status() From 96ad9706a1f9d43f30827329c29e8712fdfb1b58 Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Mon, 20 Apr 2026 10:10:58 +0300 Subject: [PATCH 14/39] feat(dovecot-charm): add HA support with SSH key exchange and force-sync action --- docs/release-notes/artifacts/pr-4-ha.yaml | 4 + docs/release-notes/index.rst | 1 + docs/release-notes/release-notes-0005.rst | 57 ++++++++ dovecot-charm/charmcraft.yaml | 2 + dovecot-charm/src/charm.py | 127 ++++++++++++++++++ .../templates/sync-to-secondary.sh.tmpl | 20 +++ .../templates/sync-to-secondary_cron.tmpl | 3 + dovecot-charm/tests/integration/test_ha.py | 89 ++++++++++++ dovecot-charm/tests/unit/test_charm.py | 87 +++++++++++- 9 files changed, 389 insertions(+), 1 deletion(-) create mode 100644 docs/release-notes/artifacts/pr-4-ha.yaml create mode 100644 docs/release-notes/release-notes-0005.rst create mode 100644 dovecot-charm/templates/sync-to-secondary.sh.tmpl create mode 100644 dovecot-charm/templates/sync-to-secondary_cron.tmpl create mode 100644 dovecot-charm/tests/integration/test_ha.py diff --git a/docs/release-notes/artifacts/pr-4-ha.yaml b/docs/release-notes/artifacts/pr-4-ha.yaml new file mode 100644 index 0000000..300a046 --- /dev/null +++ b/docs/release-notes/artifacts/pr-4-ha.yaml @@ -0,0 +1,4 @@ +name: pr-4-ha +type: major +summary: HA support with SSH key exchange and force-sync action +url: https://github.com/canonical/mailserver-operators/pull/4 diff --git a/docs/release-notes/index.rst b/docs/release-notes/index.rst index 23c635d..9ca49e0 100644 --- a/docs/release-notes/index.rst +++ b/docs/release-notes/index.rst @@ -35,3 +35,4 @@ Releases release-notes-0002 release-notes-0003 release-notes-0004 + release-notes-0005 diff --git a/docs/release-notes/release-notes-0005.rst b/docs/release-notes/release-notes-0005.rst new file mode 100644 index 0000000..9a5ac16 --- /dev/null +++ b/docs/release-notes/release-notes-0005.rst @@ -0,0 +1,57 @@ +.. _release_notes_release_notes_0005: + +Dovecot release notes – 2.3/edge +================================= + +These release notes cover new features and changes in Dovecot. + +Main features: + +* Added HA support with SSH key exchange and ``force-sync`` action. + +See our :ref:`Release policy and schedule `. + +Requirements and compatibility +------------------------------- + +The charm operates Dovecot 2.3. + +.. list-table:: + :header-rows: 1 + :widths: 50 50 + + * - Software + - Required version + * - Juju + - 3.x + * - Ubuntu + - 24.04 + +Updates +------- + +The following major and minor features were added in this release. + +HA support with SSH key exchange and force-sync action +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +High-availability support was added to the Dovecot charm. The charm now +exchanges SSH keys between primary and secondary units during installation, +enabling passwordless root SSH access required for mail pool synchronisation. +A new ``force-sync`` action was introduced, allowing operators to trigger an +immediate synchronisation of the mail pool from the primary unit to the +secondary unit on demand. + +Relevant links: + +* `PR `_ + +Bug fixes +--------- + +No bug fixes in this release. + +Known issues +------------ + +No known issues. diff --git a/dovecot-charm/charmcraft.yaml b/dovecot-charm/charmcraft.yaml index 7ef911e..616d42e 100644 --- a/dovecot-charm/charmcraft.yaml +++ b/dovecot-charm/charmcraft.yaml @@ -97,3 +97,5 @@ actions: description: The queue to clear (deferred or all). default: deferred enum: [deferred, all] + force-sync: + description: Manually trigger synchronization of mail to the secondary unit. diff --git a/dovecot-charm/src/charm.py b/dovecot-charm/src/charm.py index 0b5fddf..7219f42 100644 --- a/dovecot-charm/src/charm.py +++ b/dovecot-charm/src/charm.py @@ -5,7 +5,9 @@ """Dovecot charm.""" import logging +import os import shutil +import socket import subprocess # nosec import typing from pathlib import Path @@ -57,6 +59,8 @@ def __init__(self, *args): self.framework.observe(self.on.clear_queue_action, self._on_clear_queue_action) self.framework.observe(self.on.mail_data_storage_attached, self._reconcile) self.framework.observe(self.on.mail_data_storage_detaching, self._reconcile) + self.framework.observe(self.on.replicas_relation_changed, self._on_replicas_changed) + self.framework.observe(self.on.force_sync_action, self._on_force_sync) self.framework.observe( self.on[PEER_RELATION_NAME].relation_created, @@ -67,6 +71,13 @@ def __init__(self, *args): loader=jinja2.FileSystemLoader(TEMPLATES_DIR), autoescape=True ) + # Sync to secondary + self.sync_smtp_aliases_target = "/usr/local/bin/sync-smtp-aliases.sh" + self.sync_to_secondary_target = "/usr/local/bin/sync-to-secondary.sh" + self.sync_to_secondary_cronjob_target = "/etc/cron.d/sync-to-secondary" + self.sync_to_secondary_template = "sync-to-secondary.sh.tmpl" + self.sync_to_secondary_cronjob_template = "sync-to-secondary_cron.tmpl" + # TLS certificates integration self._tls = None mailname = self.config.get("mailname", "") @@ -107,6 +118,11 @@ def _on_peer_relation_created(self, event): relation_data = event.relation.data[self.unit] relation_data["unit-name"] = self.unit.name + @property + def _is_primary(self): + """Return True if this unit is the configured primary unit.""" + return self.unit.name == self.config.get("primary-unit", "") + def _get_dovecot_config(self) -> DovecotConfig: """Craft the DovecotConfig from charm configuration and validate it. @@ -163,6 +179,10 @@ def _install(self): apt.update() apt.add_package(REQUIRED_PACKAGES) shutil.copy(HOSTNAME_FILE, MAILNAME_FILE) + self._setup_ssh_keys() + if self._is_primary: + self._install_mail_sync_script() + self._setup_mail_sync_cronjob() self.unit.status = MaintenanceStatus("Charm installation done") def _open_ports(self): @@ -271,6 +291,113 @@ def _on_clear_queue_action(self, event): logger.exception(f"Failed to clear Postfix queue: {e.stderr}") event.fail(f"Failed to run postsuper: {e.stderr}") + @property + def _secondary_hostname(self): + """Return the hostname/IP of the secondary unit.""" + relation = self.model.get_relation("replicas") + if not relation: + return None + + for unit in relation.units: + return ( + relation.data[unit].get("hostname") + or relation.data[unit].get("private-address") + or relation.data[unit].get("ingress-address") + ) + + return None + + def _setup_ssh_keys(self): + """Generate SSH key and share public key via peer relation.""" + ssh_dir = Path("/root/.ssh") + ssh_dir.mkdir(mode=0o700, exist_ok=True) + key_file = ssh_dir / "id_ed25519" + + if not key_file.exists(): + logger.warning("keyfile not there") + os.system(f'ssh-keygen -t ed25519 -N "" -f {key_file}') # noqa: S605 + + pub_key = (ssh_dir / "id_ed25519.pub").read_text().strip() + relation = self.model.get_relation("replicas") + if relation: + relation.data[self.unit]["public_key"] = pub_key + relation.data[self.unit]["hostname"] = socket.gethostname() + + config_file = ssh_dir / "config" + if not config_file.exists(): + config_file.write_text("Host *\n StrictHostKeyChecking no\n") + config_file.chmod(0o600) + + def _on_replicas_changed(self, event): + """Handle replicas relation changed — sync SSH authorized_keys.""" + authorized_keys = [] + relation = self.model.get_relation("replicas") + + for unit in relation.units: + pk = relation.data[unit].get("public_key") + if pk: + authorized_keys.append(pk) + + our_pk = relation.data[self.unit].get("public_key") + if our_pk: + authorized_keys.append(our_pk) + + auth_file = Path("/root/.ssh/authorized_keys") + auth_file.write_text("\n".join(authorized_keys)) + auth_file.chmod(0o600) + + self._ensure_root_ssh_configs() + + def _ensure_root_ssh_configs(self): + """Ensure PermitRootLogin is set in sshd_config.""" + cmd = "sed -i 's/^#*PermitRootLogin.*/PermitRootLogin prohibit-password/' /etc/ssh/sshd_config" + os.system(cmd) # noqa: S605 + os.system("systemctl restart ssh") # noqa: S605, S607 + + def _install_mail_sync_script(self): + """Install mail pool synchronization script.""" + self.unit.status = MaintenanceStatus("Installing mail pool synchronization script") + template_context = { + "secondary_hostname": self._secondary_hostname, + "mail_root": MAIL_ROOT, + } + template = self.jinja.get_template(self.sync_to_secondary_template) + contents = template.render(template_context) + host.write_file(self.sync_to_secondary_target, contents, perms=0o755) + self.unit.status = MaintenanceStatus("Mail pool synchronization installed") + + def _setup_mail_sync_cronjob(self): + """Set up mail pool synchronization cronjob.""" + self.unit.status = MaintenanceStatus("Setting up mail pool synchronization cronjob") + template_context = { + "schedule": self.config.get("sync-schedule", "*/30 * * * *"), + } + template = self.jinja.get_template(self.sync_to_secondary_cronjob_template) + contents = template.render(template_context) + host.write_file(self.sync_to_secondary_cronjob_target, contents, perms=0o644) + systemd.service_restart("cron") + self.unit.status = MaintenanceStatus("Mail pool synchronization cronjob has been set up") + + def _on_force_sync(self, event): + """Force synchronization with secondary unit.""" + if not self._is_primary: + event.fail("This action can only be run on the primary unit.") + return + + if not self._secondary_hostname: + event.fail("No secondary unit found to sync to.") + return + + try: + cmd = [self.sync_to_secondary_target] + logger.info(f"Running manual sync: {' '.join(cmd)}") + subprocess.run(cmd, check=True, capture_output=True, text=True) + event.set_results({"result": "Sync completed successfully"}) + except subprocess.CalledProcessError as e: + msg = f"Sync failed: {e.stderr}" + logger.error(msg) + event.fail(msg) + def _setup_tls(self, dovecot_config: DovecotConfig) -> None: """Write TLS cert+key to disk from the certificates relation. diff --git a/dovecot-charm/templates/sync-to-secondary.sh.tmpl b/dovecot-charm/templates/sync-to-secondary.sh.tmpl new file mode 100644 index 0000000..126ad02 --- /dev/null +++ b/dovecot-charm/templates/sync-to-secondary.sh.tmpl @@ -0,0 +1,20 @@ +#!/bin/bash + +set -eu + +# Sync using doveadm (dsync) for users that have a Maildir. +# Avoids syncing system accounts without mailboxes. +remote="remote:root@{{ secondary_hostname }}" +found=0 +for user_dir in "{{ mail_root }}"/*; do + if [ -d "$user_dir/Maildir" ]; then + user="$(basename "$user_dir")" + doveadm backup -u "$user" "$remote" + found=1 + fi +done +if [ "$found" -eq 0 ]; then + echo "No Maildir found under {{ mail_root }}; nothing to sync." >&2 + exit 1 +fi +touch {{ mail_root }}/.last-dsync diff --git a/dovecot-charm/templates/sync-to-secondary_cron.tmpl b/dovecot-charm/templates/sync-to-secondary_cron.tmpl new file mode 100644 index 0000000..d3101cb --- /dev/null +++ b/dovecot-charm/templates/sync-to-secondary_cron.tmpl @@ -0,0 +1,3 @@ +{{ schedule }} root /usr/local/bin/sync-to-secondary.sh >> /var/log/sync-to-secondary.log 2>&1 + +# End of file diff --git a/dovecot-charm/tests/integration/test_ha.py b/dovecot-charm/tests/integration/test_ha.py new file mode 100644 index 0000000..c20dccb --- /dev/null +++ b/dovecot-charm/tests/integration/test_ha.py @@ -0,0 +1,89 @@ +# Copyright 2024 Canonical Ltd. +# See LICENSE file for licensing details. + +import logging +from typing import cast + +import jubilant +import pytest + + +def _get_unit_hostname(status, app_name, unit_name): + """Helper to get unit hostname from status.""" + try: + machine = status.apps[app_name].units[unit_name].machine + return status.machines[machine].hostname + except KeyError: + logging.error(f"Unit {unit_name} not found in status.") + return None + + +@pytest.mark.timeout(1800) +def test_ha_failover(juju, dovecot_charm): + status = juju.status() + if len(status.apps[dovecot_charm].units) < 2: + logging.info("Adding the second unit...") + juju.add_unit(dovecot_charm, num_units=1) + + def two_units_active(status): + app = status.apps.get(dovecot_charm) + if not app: + return False + if len(app.units) < 2: + return False + return jubilant.all_active(status) + + logging.info("Waiting for 2 units to be active...") + juju.wait(two_units_active, timeout=600) + + status = juju.status() + units = list(status.apps[dovecot_charm].units.keys()) + units.sort(key=lambda x: int(x.split("/")[-1])) + + primary = units[0] + secondary = units[1] + + logging.info(f"Primary: {primary}, Secondary: {secondary}") + + juju.config(dovecot_charm, {"primary-unit": primary}) + juju.wait(jubilant.all_active, timeout=300) + + logging.info("Verifying SSH key exchange...") + + cmd = "cat /root/.ssh/authorized_keys | wc -l" + + result_primary = juju.exec(cmd, unit=primary) + logging.info(f"Primary authorized_keys count: {result_primary.stdout.strip()}") + assert int(result_primary.stdout.strip()) >= 1 + + result_secondary = juju.exec(cmd, unit=secondary) + logging.info(f"Secondary authorized_keys count: {result_secondary.stdout.strip()}") + assert int(result_secondary.stdout.strip()) >= 1 + + logging.info("Verifying sync script on Primary...") + + status = juju.status() + secondary_hostname = _get_unit_hostname(status, dovecot_charm, secondary) + logging.info(f"Secondary hostname: {secondary_hostname}") + + script_path = "/usr/local/bin/sync-to-secondary.sh" + cmd = f"cat {script_path}" + script_content = juju.exec(cmd, unit=primary).stdout + + logging.info(f"Sync script content on Primary:\n{script_content}") + assert secondary_hostname in script_content, ( + "Secondary hostname not found in sync script on Primary" + ) + + logging.info("Running force-sync on Primary...") + + task = juju.run(unit=primary, action="force-sync", wait=100) + assert task.status == "completed" + assert task.results["result"] == "Sync completed successfully" + + with pytest.raises(jubilant.TaskError) as exc_info: + juju.run(unit=secondary, action="force-sync", wait=100) + assert cast(jubilant.TaskError, exc_info.value).task.status == "failed" + logging.info("force-sync on Secondary correctly failed.") + + logging.info("HA Failover test passed.") diff --git a/dovecot-charm/tests/unit/test_charm.py b/dovecot-charm/tests/unit/test_charm.py index 9691eec..e697d98 100644 --- a/dovecot-charm/tests/unit/test_charm.py +++ b/dovecot-charm/tests/unit/test_charm.py @@ -1,12 +1,14 @@ # Copyright 2026 Canonical Ltd. # See LICENSE file for licensing details. +import dataclasses from subprocess import CalledProcessError # nosec -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, PropertyMock, patch import ops import ops.testing import pytest +from charm import DovecotCharm from exceptions import ConfigurationError @@ -126,3 +128,86 @@ def test_clear_queue_failure(ctx, base_state): base_state, ) assert "postsuper" in exc_info.value.message + + +def test_install_calls_all_setup_steps(ctx, base_state): + with ( + patch("charm.apt") as mock_apt, + patch("charm.shutil.copy") as mock_copy, + patch("charm.DovecotCharm._open_ports") as mock_open_ports, + patch("charm.DovecotCharm._setup_dovecot") as mock_dovecot, + patch("charm.DovecotCharm._setup_procmail") as mock_procmail, + patch("charm.DovecotCharm._setup_ssh_keys") as mock_setup_ssh, + patch("charm.DovecotCharm._install_mail_sync_script"), + patch("charm.DovecotCharm._setup_mail_sync_cronjob"), + ): + ctx.run(ctx.on.install(), base_state) + + mock_apt.update.assert_called_once() + mock_apt.add_package.assert_called_once() + mock_copy.assert_called_once_with("/etc/hostname", "/etc/mailname") + mock_open_ports.assert_called_once() + mock_dovecot.assert_called_once() + mock_procmail.assert_called_once() + mock_setup_ssh.assert_called_once() + + +def test_is_primary_true(ctx, base_state): + with patch("charm.DovecotCharm._install"), ctx(ctx.on.config_changed(), base_state) as mgr: + assert mgr.charm._is_primary is True + + +def test_is_primary_false(ctx, base_state): + state_in = dataclasses.replace( + base_state, config={**base_state.config, "primary-unit": "dovecot-charm/999"} + ) + with patch("charm.DovecotCharm._install"), ctx(ctx.on.config_changed(), state_in) as mgr: + assert mgr.charm._is_primary is False + + +def test_force_sync_success(ctx, base_state): + mock_result = MagicMock(stdout="ok", stderr="") + with ( + patch("charm.subprocess.run", return_value=mock_result), + patch.object( + DovecotCharm, + "_secondary_hostname", + new_callable=PropertyMock, + return_value="10.0.0.2", + ), + ): + ctx.run(ctx.on.action("force-sync"), base_state) + assert ctx.action_results == {"result": "Sync completed successfully"} + + +def test_force_sync_not_primary(ctx, base_state): + state_in = dataclasses.replace( + base_state, config={**base_state.config, "primary-unit": "dovecot-charm/999"} + ) + with pytest.raises(ops.testing.ActionFailed) as exc_info: + ctx.run(ctx.on.action("force-sync"), state_in) + assert "primary unit" in exc_info.value.message + + +def test_force_sync_no_secondary(ctx, base_state): + with pytest.raises(ops.testing.ActionFailed) as exc_info: + ctx.run(ctx.on.action("force-sync"), base_state) + assert "secondary" in exc_info.value.message + + +def test_force_sync_subprocess_failure(ctx, base_state): + with ( + patch( + "charm.subprocess.run", + side_effect=CalledProcessError(1, "sync", stderr="fail"), + ), + patch.object( + DovecotCharm, + "_secondary_hostname", + new_callable=PropertyMock, + return_value="10.0.0.2", + ), + pytest.raises(ops.testing.ActionFailed) as exc_info, + ): + ctx.run(ctx.on.action("force-sync"), base_state) + assert "fail" in exc_info.value.message From 6c78502c4f2f253b995adcc077234ab3f058d768 Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Mon, 20 Apr 2026 10:10:59 +0300 Subject: [PATCH 15/39] docs: add release notes for pr/4-ha From 2224e3aadd4d2b36d115675af2d97e0236a938cd Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Mon, 20 Apr 2026 11:20:37 +0300 Subject: [PATCH 16/39] refactor(tests): clean up TLS test cases by removing unused test and redundant imports --- dovecot-charm/pyproject.toml | 1 - dovecot-charm/tests/unit/test_tls.py | 13 +------------ 2 files changed, 1 insertion(+), 13 deletions(-) diff --git a/dovecot-charm/pyproject.toml b/dovecot-charm/pyproject.toml index 2848262..bbee4c4 100644 --- a/dovecot-charm/pyproject.toml +++ b/dovecot-charm/pyproject.toml @@ -122,7 +122,6 @@ lint.per-file-ignores."tests/*" = [ "D417", "S", ] - lint.flake8-copyright.author = "Canonical Ltd." lint.flake8-copyright.min-file-size = 1 lint.flake8-copyright.notice-rgx = "Copyright\\s\\d{4}([-,]\\d{4})*\\s+" diff --git a/dovecot-charm/tests/unit/test_tls.py b/dovecot-charm/tests/unit/test_tls.py index f654cbc..fdbd869 100644 --- a/dovecot-charm/tests/unit/test_tls.py +++ b/dovecot-charm/tests/unit/test_tls.py @@ -2,8 +2,8 @@ # See LICENSE file for licensing details. """Unit tests for TLS certificate integration.""" -import dataclasses from unittest.mock import MagicMock, patch +from exceptions import ConfigurationError import ops import pytest @@ -26,15 +26,6 @@ def test_no_tls_cert_yet_blocks(ctx, base_state): assert "certificate" in state_out.unit_status.message.lower() -def test_no_tls_relation_blocks(ctx, base_state): - """Charm must be Blocked when mailname is empty (so _tls is None).""" - state_in = dataclasses.replace(base_state, config={**base_state.config, "mailname": ""}) - # Empty mailname → _tls is None AND DovecotConfig validation fails first - # (mailname is required by pydantic); either way the charm must be Blocked - state_out = ctx.run(ctx.on.config_changed(), state_in) - assert isinstance(state_out.unit_status, ops.BlockedStatus) - - def test_setup_tls_writes_cert_key_and_chain(ctx, base_state, tmp_path): """_setup_tls writes cert (+ CA chain) and private key to tls_cert_dir. @@ -112,8 +103,6 @@ def test_setup_tls_no_private_key_raises(ctx, base_state): TLS_CERT_DIR patch is needed. At __exit__ _reconcile also calls _setup_tls with the same mock, hits the same error, and sets BlockedStatus. """ - from exceptions import ConfigurationError - mock_cert = MagicMock() with ( From e3ba9af962f40af350d595b50b2de146242336a6 Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Mon, 20 Apr 2026 11:33:07 +0300 Subject: [PATCH 17/39] refactor(ha): holistic reconcile, fix security and test issues MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Move all HA setup (SSH keys, authorized_keys, sync script, cronjob) into _reconcile so they're re-evaluated on every event, not just install - Replace os.system calls with subprocess.run and systemd.service_reload - Replace sed-based sshd_config mutation with pure Python - Guard _sync_authorized_keys against missing peer relation - Skip sync script/cronjob when secondary hostname is not yet known - Remove dead code (sync_smtp_aliases_target) - Remove _on_replicas_changed handler — folded into _reconcile - Fix _setup_ssh_keys to handle keygen failure gracefully - Rewrite unit tests per SKILL.md principles: assert on observable state (unit_status, opened_ports), comment every patch - Add HA patches to storage and TLS tests broken by holistic reconcile - Fix integration test copyright year and force-sync empty-maildir bug --- dovecot-charm/src/charm.py | 299 ++++++++++++--------- dovecot-charm/tests/integration/test_ha.py | 6 +- dovecot-charm/tests/unit/test_charm.py | 205 ++++++++++---- dovecot-charm/tests/unit/test_storage.py | 30 +++ dovecot-charm/tests/unit/test_tls.py | 15 ++ 5 files changed, 374 insertions(+), 181 deletions(-) diff --git a/dovecot-charm/src/charm.py b/dovecot-charm/src/charm.py index 7219f42..328f8cf 100644 --- a/dovecot-charm/src/charm.py +++ b/dovecot-charm/src/charm.py @@ -5,7 +5,6 @@ """Dovecot charm.""" import logging -import os import shutil import socket import subprocess # nosec @@ -44,6 +43,15 @@ logger = logging.getLogger(__name__) +# HA sync paths +SYNC_TO_SECONDARY_TARGET = "/usr/local/bin/sync-to-secondary.sh" +SYNC_TO_SECONDARY_CRONJOB_TARGET = "/etc/cron.d/sync-to-secondary" +SYNC_TO_SECONDARY_TEMPLATE = "sync-to-secondary.sh.tmpl" +SYNC_TO_SECONDARY_CRONJOB_TEMPLATE = "sync-to-secondary_cron.tmpl" + +SSHD_CONFIG = Path("/etc/ssh/sshd_config") +SSH_DIR = Path("/root/.ssh") + class DovecotCharm(CharmBase): """Dovecot IMAP/POP3 mail server charm.""" @@ -51,7 +59,7 @@ class DovecotCharm(CharmBase): def __init__(self, *args): super().__init__(*args) - # Events + # Events — every event except install goes through _reconcile. self.framework.observe(self.on.install, self._on_install) self.framework.observe(self.on.start, self._reconcile) self.framework.observe(self.on.config_changed, self._reconcile) @@ -59,7 +67,7 @@ def __init__(self, *args): self.framework.observe(self.on.clear_queue_action, self._on_clear_queue_action) self.framework.observe(self.on.mail_data_storage_attached, self._reconcile) self.framework.observe(self.on.mail_data_storage_detaching, self._reconcile) - self.framework.observe(self.on.replicas_relation_changed, self._on_replicas_changed) + self.framework.observe(self.on.replicas_relation_changed, self._reconcile) self.framework.observe(self.on.force_sync_action, self._on_force_sync) self.framework.observe( @@ -71,13 +79,6 @@ def __init__(self, *args): loader=jinja2.FileSystemLoader(TEMPLATES_DIR), autoescape=True ) - # Sync to secondary - self.sync_smtp_aliases_target = "/usr/local/bin/sync-smtp-aliases.sh" - self.sync_to_secondary_target = "/usr/local/bin/sync-to-secondary.sh" - self.sync_to_secondary_cronjob_target = "/etc/cron.d/sync-to-secondary" - self.sync_to_secondary_template = "sync-to-secondary.sh.tmpl" - self.sync_to_secondary_cronjob_template = "sync-to-secondary_cron.tmpl" - # TLS certificates integration self._tls = None mailname = self.config.get("mailname", "") @@ -123,6 +124,24 @@ def _is_primary(self): """Return True if this unit is the configured primary unit.""" return self.unit.name == self.config.get("primary-unit", "") + @property + def _secondary_hostname(self) -> typing.Optional[str]: + """Return the hostname/IP of the first remote peer unit, or None.""" + relation = self.model.get_relation(PEER_RELATION_NAME) + if not relation: + return None + + for unit in relation.units: + hostname = ( + relation.data[unit].get("hostname") + or relation.data[unit].get("private-address") + or relation.data[unit].get("ingress-address") + ) + if hostname: + return hostname + + return None + def _get_dovecot_config(self) -> DovecotConfig: """Craft the DovecotConfig from charm configuration and validate it. @@ -144,14 +163,19 @@ def _get_dovecot_config(self) -> DovecotConfig: logger.exception(f"Secret retrieval error: {exc}") raise ConfigurationError(str(exc)) from exc + # -- Event handlers ------------------------------------------------------- + def _on_install(self, event): - """Handle install event.""" + """Handle install event — install packages only, then reconcile.""" self.unit.status = MaintenanceStatus("Installing packages") self._install() self._reconcile(event) def _reconcile(self, event): - """Reconcile charm state for install, upgrade, config-changed, and storage events.""" + """Reconcile charm state for every event except install. + + Holistic handler: storage → TLS → dovecot → procmail → HA → ports. + """ self.unit.status = MaintenanceStatus("Configuring charm") try: dovecot_config = self._get_dovecot_config() @@ -170,21 +194,27 @@ def _reconcile(self, event): except ConfigurationError as e: self.unit.status = BlockedStatus(str(e)) return + # HA: SSH keys, authorized_keys, sync script + cronjob + self._setup_ssh_keys() + self._sync_authorized_keys() + if self._is_primary: + self._install_mail_sync_script() + self._setup_mail_sync_cronjob() self._open_ports() self.unit.status = ops.ActiveStatus() + # -- Installation --------------------------------------------------------- + def _install(self): - """Perform basic installation.""" + """Perform basic installation — packages and hostname only.""" self.unit.status = MaintenanceStatus("Installing required dependencies") apt.update() apt.add_package(REQUIRED_PACKAGES) shutil.copy(HOSTNAME_FILE, MAILNAME_FILE) - self._setup_ssh_keys() - if self._is_primary: - self._install_mail_sync_script() - self._setup_mail_sync_cronjob() self.unit.status = MaintenanceStatus("Charm installation done") + # -- Service configuration ------------------------------------------------ + def _open_ports(self): """Open mail ports.""" self.unit.open_port("tcp", 143) @@ -268,71 +298,84 @@ def _setup_procmail(self) -> None: logger.exception(f"Failed to configure postfix: {e}") raise ConfigurationError(f"Failed to configure postfix: {e.stderr}") from e - def _on_clear_queue_action(self, event): - """Handle the clear-queue action.""" - queue_to_clear = event.params.get("queue", "deferred") + def _setup_tls(self, dovecot_config: DovecotConfig) -> None: + """Write TLS cert+key to disk from the certificates relation. - if queue_to_clear not in ("deferred", "all"): - event.fail("Invalid queue parameter, must be 'deferred' or 'all'") - return - command = ["postsuper", "-d", "ALL"] + Called from _reconcile before _setup_dovecot so the cert files are + present when dovecot.conf is rendered and validated. - if queue_to_clear == "all": - logger.warning("Running clear-queue action: DELETING ALL mail from Postfix queue.") - else: - command.append("deferred") - logger.info("Running clear-queue action: Deleting deferred mail from Postfix queue.") + Raises: + ConfigurationError: If no TLS relation exists or the certificate + has not been issued yet. + """ + if not self._tls: + raise ConfigurationError( + "TLS certificates relation not available. " + "Integrate with a TLS provider using the 'certificates' relation." + ) - try: - # The command and arguments are fixed literals with no user-controlled input. - result = subprocess.run(command, check=True, capture_output=True, text=True) - event.set_results({"status": "success", "output": result.stdout}) - except subprocess.CalledProcessError as e: - logger.exception(f"Failed to clear Postfix queue: {e.stderr}") - event.fail(f"Failed to run postsuper: {e.stderr}") + cert_request = CertificateRequestAttributes( + common_name=dovecot_config.mailname, + sans_dns=frozenset([dovecot_config.mailname]), + ) + provider_cert, private_key = self._tls.get_assigned_certificate(cert_request) + if not provider_cert or not private_key: + raise ConfigurationError( + "TLS certificate not yet available from the certificates relation." + ) - @property - def _secondary_hostname(self): - """Return the hostname/IP of the secondary unit.""" - relation = self.model.get_relation("replicas") - if not relation: - return None + TLS_CERT_DIR.mkdir(parents=True, exist_ok=True) + cert_path = TLS_CERT_DIR / f"{dovecot_config.mailname}.pem" + key_path = TLS_CERT_DIR / f"{dovecot_config.mailname}.key" - for unit in relation.units: - return ( - relation.data[unit].get("hostname") - or relation.data[unit].get("private-address") - or relation.data[unit].get("ingress-address") - ) + cert_content = str(provider_cert.certificate) + if provider_cert.ca: + cert_content += "\n" + str(provider_cert.ca) + cert_path.write_text(cert_content) + cert_path.chmod(0o644) + logger.info(f"TLS certificate written to {cert_path}") - return None + key_path.write_text(str(private_key)) + key_path.chmod(0o600) + logger.info(f"TLS private key written to {key_path}") + + # -- HA / SSH key exchange ------------------------------------------------ def _setup_ssh_keys(self): - """Generate SSH key and share public key via peer relation.""" - ssh_dir = Path("/root/.ssh") - ssh_dir.mkdir(mode=0o700, exist_ok=True) - key_file = ssh_dir / "id_ed25519" + """Generate an SSH key pair if absent and publish the public key via the peer relation.""" + SSH_DIR.mkdir(mode=0o700, exist_ok=True) + key_file = SSH_DIR / "id_ed25519" if not key_file.exists(): - logger.warning("keyfile not there") - os.system(f'ssh-keygen -t ed25519 -N "" -f {key_file}') # noqa: S605 + subprocess.run( # noqa: S603 + ["ssh-keygen", "-t", "ed25519", "-N", "", "-f", str(key_file)], + check=True, + capture_output=True, + ) - pub_key = (ssh_dir / "id_ed25519.pub").read_text().strip() - relation = self.model.get_relation("replicas") + pub_key_file = SSH_DIR / "id_ed25519.pub" + if not pub_key_file.exists(): + logger.error("SSH public key file not found after key generation") + return + + pub_key = pub_key_file.read_text().strip() + relation = self.model.get_relation(PEER_RELATION_NAME) if relation: relation.data[self.unit]["public_key"] = pub_key relation.data[self.unit]["hostname"] = socket.gethostname() - config_file = ssh_dir / "config" + config_file = SSH_DIR / "config" if not config_file.exists(): config_file.write_text("Host *\n StrictHostKeyChecking no\n") config_file.chmod(0o600) - def _on_replicas_changed(self, event): - """Handle replicas relation changed — sync SSH authorized_keys.""" - authorized_keys = [] - relation = self.model.get_relation("replicas") + def _sync_authorized_keys(self): + """Collect public keys from all peer units and write authorized_keys.""" + relation = self.model.get_relation(PEER_RELATION_NAME) + if not relation: + return + authorized_keys = [] for unit in relation.units: pk = relation.data[unit].get("public_key") if pk: @@ -342,41 +385,92 @@ def _on_replicas_changed(self, event): if our_pk: authorized_keys.append(our_pk) - auth_file = Path("/root/.ssh/authorized_keys") - auth_file.write_text("\n".join(authorized_keys)) - auth_file.chmod(0o600) + if not authorized_keys: + return - self._ensure_root_ssh_configs() + auth_file = SSH_DIR / "authorized_keys" + auth_file.write_text("\n".join(authorized_keys) + "\n") + auth_file.chmod(0o600) - def _ensure_root_ssh_configs(self): - """Ensure PermitRootLogin is set in sshd_config.""" - cmd = "sed -i 's/^#*PermitRootLogin.*/PermitRootLogin prohibit-password/' /etc/ssh/sshd_config" - os.system(cmd) # noqa: S605 - os.system("systemctl restart ssh") # noqa: S605, S607 + self._ensure_root_ssh_login() + + def _ensure_root_ssh_login(self): + """Set PermitRootLogin to prohibit-password in sshd_config and reload sshd.""" + if SSHD_CONFIG.exists(): + content = SSHD_CONFIG.read_text() + new_content = "" + found = False + for line in content.splitlines(keepends=True): + stripped = line.lstrip("#").strip() + if stripped.startswith("PermitRootLogin"): + new_content += "PermitRootLogin prohibit-password\n" + found = True + else: + new_content += line + if not found: + new_content += "\nPermitRootLogin prohibit-password\n" + if new_content != content: + SSHD_CONFIG.write_text(new_content) + systemd.service_reload("ssh", restart_on_failure=True) def _install_mail_sync_script(self): - """Install mail pool synchronization script.""" + """Render and install the mail pool synchronization script. + + Skipped when the secondary hostname is not yet known (no remote peer). + """ + secondary = self._secondary_hostname + if not secondary: + logger.info("Secondary hostname not yet known; skipping sync script installation") + return + self.unit.status = MaintenanceStatus("Installing mail pool synchronization script") template_context = { - "secondary_hostname": self._secondary_hostname, + "secondary_hostname": secondary, "mail_root": MAIL_ROOT, } - template = self.jinja.get_template(self.sync_to_secondary_template) + template = self.jinja.get_template(SYNC_TO_SECONDARY_TEMPLATE) contents = template.render(template_context) - host.write_file(self.sync_to_secondary_target, contents, perms=0o755) - self.unit.status = MaintenanceStatus("Mail pool synchronization installed") + host.write_file(SYNC_TO_SECONDARY_TARGET, contents, perms=0o755) def _setup_mail_sync_cronjob(self): - """Set up mail pool synchronization cronjob.""" + """Set up the mail pool synchronization cronjob.""" + if not self._secondary_hostname: + logger.info("Secondary hostname not yet known; skipping cronjob setup") + return + self.unit.status = MaintenanceStatus("Setting up mail pool synchronization cronjob") template_context = { "schedule": self.config.get("sync-schedule", "*/30 * * * *"), } - template = self.jinja.get_template(self.sync_to_secondary_cronjob_template) + template = self.jinja.get_template(SYNC_TO_SECONDARY_CRONJOB_TEMPLATE) contents = template.render(template_context) - host.write_file(self.sync_to_secondary_cronjob_target, contents, perms=0o644) + host.write_file(SYNC_TO_SECONDARY_CRONJOB_TARGET, contents, perms=0o644) systemd.service_restart("cron") - self.unit.status = MaintenanceStatus("Mail pool synchronization cronjob has been set up") + + # -- Actions -------------------------------------------------------------- + + def _on_clear_queue_action(self, event): + """Handle the clear-queue action.""" + queue_to_clear = event.params.get("queue", "deferred") + + if queue_to_clear not in ("deferred", "all"): + event.fail("Invalid queue parameter, must be 'deferred' or 'all'") + return + command = ["postsuper", "-d", "ALL"] + + if queue_to_clear == "all": + logger.warning("Running clear-queue action: DELETING ALL mail from Postfix queue.") + else: + command.append("deferred") + logger.info("Running clear-queue action: Deleting deferred mail from Postfix queue.") + + try: + # The command and arguments are fixed literals with no user-controlled input. + result = subprocess.run(command, check=True, capture_output=True, text=True) + event.set_results({"status": "success", "output": result.stdout}) + except subprocess.CalledProcessError as e: + logger.exception(f"Failed to clear Postfix queue: {e.stderr}") + event.fail(f"Failed to run postsuper: {e.stderr}") def _on_force_sync(self, event): """Force synchronization with secondary unit.""" @@ -389,56 +483,15 @@ def _on_force_sync(self, event): return try: - cmd = [self.sync_to_secondary_target] + cmd = [SYNC_TO_SECONDARY_TARGET] logger.info(f"Running manual sync: {' '.join(cmd)}") - subprocess.run(cmd, check=True, capture_output=True, text=True) + subprocess.run(cmd, check=True, capture_output=True, text=True) # noqa: S603 event.set_results({"result": "Sync completed successfully"}) except subprocess.CalledProcessError as e: msg = f"Sync failed: {e.stderr}" logger.error(msg) event.fail(msg) - def _setup_tls(self, dovecot_config: DovecotConfig) -> None: - """Write TLS cert+key to disk from the certificates relation. - - Called from _reconcile before _setup_dovecot so the cert files are - present when dovecot.conf is rendered and validated. - - Raises: - ConfigurationError: If no TLS relation exists or the certificate - has not been issued yet. - """ - if not self._tls: - raise ConfigurationError( - "TLS certificates relation not available. " - "Integrate with a TLS provider using the 'certificates' relation." - ) - - cert_request = CertificateRequestAttributes( - common_name=dovecot_config.mailname, - sans_dns=frozenset([dovecot_config.mailname]), - ) - provider_cert, private_key = self._tls.get_assigned_certificate(cert_request) - if not provider_cert or not private_key: - raise ConfigurationError( - "TLS certificate not yet available from the certificates relation." - ) - - TLS_CERT_DIR.mkdir(parents=True, exist_ok=True) - cert_path = TLS_CERT_DIR / f"{dovecot_config.mailname}.pem" - key_path = TLS_CERT_DIR / f"{dovecot_config.mailname}.key" - - cert_content = str(provider_cert.certificate) - if provider_cert.ca: - cert_content += "\n" + str(provider_cert.ca) - cert_path.write_text(cert_content) - cert_path.chmod(0o644) - logger.info(f"TLS certificate written to {cert_path}") - - key_path.write_text(str(private_key)) - key_path.chmod(0o600) - logger.info(f"TLS private key written to {key_path}") - if __name__ == "__main__": # pragma: nocover main(DovecotCharm) diff --git a/dovecot-charm/tests/integration/test_ha.py b/dovecot-charm/tests/integration/test_ha.py index c20dccb..4fccd38 100644 --- a/dovecot-charm/tests/integration/test_ha.py +++ b/dovecot-charm/tests/integration/test_ha.py @@ -1,4 +1,4 @@ -# Copyright 2024 Canonical Ltd. +# Copyright 2026 Canonical Ltd. # See LICENSE file for licensing details. import logging @@ -77,6 +77,10 @@ def two_units_active(status): logging.info("Running force-sync on Primary...") + # Create a test Maildir so the sync script has something to sync. + # Without this, the script exits 1 because no Maildir directories exist. + juju.exec("mkdir -p /srv/mail/testuser/Maildir/{new,cur,tmp}", unit=primary) + task = juju.run(unit=primary, action="force-sync", wait=100) assert task.status == "completed" assert task.results["result"] == "Sync completed successfully" diff --git a/dovecot-charm/tests/unit/test_charm.py b/dovecot-charm/tests/unit/test_charm.py index e697d98..8fd31a1 100644 --- a/dovecot-charm/tests/unit/test_charm.py +++ b/dovecot-charm/tests/unit/test_charm.py @@ -1,5 +1,6 @@ # Copyright 2026 Canonical Ltd. # See LICENSE file for licensing details. +import contextlib import dataclasses from subprocess import CalledProcessError # nosec from unittest.mock import MagicMock, PropertyMock, patch @@ -12,42 +13,71 @@ from exceptions import ConfigurationError -def test_open_ports(ctx, base_state): +# --------------------------------------------------------------------------- +# Helpers — patches shared across many tests +# --------------------------------------------------------------------------- + +@contextlib.contextmanager +def reconcile_guards(): + """Guard all I/O in _reconcile so tests only exercise event wiring / status. + + Use when the test drives an event that triggers _reconcile but the test + is NOT about the logic inside these helpers (storage, TLS, dovecot, etc.). + """ with ( - # Guard real storage/TLS/dovecot operations so only port logic is exercised + # storage module talks to cryptsetup / mount — not under test patch("charm.ensure_storage_ready"), patch("charm.teardown_detaching_storage"), + # doveconf binary check — pretend it's installed patch("charm.shutil.which", return_value="/usr/bin/doveconf"), + # TLS writes cert/key files to disk — not under test patch("charm.DovecotCharm._setup_tls"), + # dovecot config rendering + validation + reload — not under test patch("charm.DovecotCharm._setup_dovecot"), + # procmail config rendering + postfix postconf — not under test patch("charm.DovecotCharm._setup_procmail"), + # SSH keygen + filesystem writes — not under test + patch("charm.DovecotCharm._setup_ssh_keys"), + # authorized_keys sync — not under test + patch("charm.DovecotCharm._sync_authorized_keys"), + # sync script rendering — not under test + patch("charm.DovecotCharm._install_mail_sync_script"), + # cronjob rendering + cron restart — not under test + patch("charm.DovecotCharm._setup_mail_sync_cronjob"), ): - state_out = ctx.run(ctx.on.config_changed(), base_state) + yield - expected = {ops.testing.TCPPort(p) for p in [143, 993, 110, 995, 4190, 9900]} - assert state_out.opened_ports == expected +# --------------------------------------------------------------------------- +# Reconcile: status + ports +# --------------------------------------------------------------------------- -def test_configure_sets_active_on_success(ctx, base_state): - with ( - patch("charm.ensure_storage_ready"), - patch("charm.teardown_detaching_storage"), - patch("charm.shutil.which", return_value="/usr/bin/doveconf"), - patch("charm.DovecotCharm._setup_tls"), - patch("charm.DovecotCharm._setup_dovecot"), - patch("charm.DovecotCharm._setup_procmail"), - ): + +def test_reconcile_sets_active_on_success(ctx, base_state): + """Reconcile must reach ActiveStatus when all setup steps succeed.""" + with reconcile_guards(): state_out = ctx.run(ctx.on.config_changed(), base_state) assert isinstance(state_out.unit_status, ops.ActiveStatus) -def test_configure_blocks_when_dovecot_setup_fails(ctx, base_state): +def test_reconcile_opens_mail_ports(ctx, base_state): + """All required IMAP/POP3/Sieve/metrics ports must be opened.""" + with reconcile_guards(): + state_out = ctx.run(ctx.on.config_changed(), base_state) + + expected = {ops.testing.TCPPort(p) for p in [143, 993, 110, 995, 4190, 9900]} + assert state_out.opened_ports == expected + + +def test_reconcile_blocks_when_dovecot_setup_fails(ctx, base_state): + """Charm must be Blocked when _setup_dovecot raises ConfigurationError.""" with ( patch("charm.ensure_storage_ready"), patch("charm.teardown_detaching_storage"), patch("charm.shutil.which", return_value="/usr/bin/doveconf"), patch("charm.DovecotCharm._setup_tls"), + # _setup_dovecot raises — this is the condition under test patch( "charm.DovecotCharm._setup_dovecot", side_effect=ConfigurationError( @@ -62,13 +92,15 @@ def test_configure_blocks_when_dovecot_setup_fails(ctx, base_state): assert "Invalid Dovecot configuration" in state_out.unit_status.message -def test_configure_blocks_when_procmail_setup_fails(ctx, base_state): +def test_reconcile_blocks_when_procmail_setup_fails(ctx, base_state): + """Charm must be Blocked when _setup_procmail raises ConfigurationError.""" with ( patch("charm.ensure_storage_ready"), patch("charm.teardown_detaching_storage"), patch("charm.shutil.which", return_value="/usr/bin/doveconf"), patch("charm.DovecotCharm._setup_tls"), patch("charm.DovecotCharm._setup_dovecot"), + # _setup_procmail raises — this is the condition under test patch( "charm.DovecotCharm._setup_procmail", side_effect=ConfigurationError("Failed to configure postfix: error"), @@ -80,12 +112,86 @@ def test_configure_blocks_when_procmail_setup_fails(ctx, base_state): assert "postfix" in state_out.unit_status.message -# --- Clear-queue action tests --- +# --------------------------------------------------------------------------- +# HA: _is_primary +# --------------------------------------------------------------------------- + + +def test_is_primary_true_when_unit_matches_config(ctx, base_state): + """_is_primary returns True when primary-unit config matches this unit.""" + # base_state has primary-unit=dovecot-charm/0; the ctx app_name gives unit dovecot-charm/0 + with reconcile_guards(), ctx(ctx.on.config_changed(), base_state) as mgr: + assert mgr.charm._is_primary is True + + +def test_is_primary_false_when_unit_differs(ctx, base_state): + """_is_primary returns False when primary-unit config doesn't match this unit. + + We access the charm inside the context manager before the event fires, + so no _reconcile I/O is reached — no patches needed for the HA methods. + Config validation is bypassed by patching _get_dovecot_config. + """ + state_in = dataclasses.replace( + base_state, config={**base_state.config, "primary-unit": "dovecot-charm/99"} + ) + with ( + # config validation rejects unknown units — bypass it since we're only testing _is_primary + patch("charm.DovecotCharm._get_dovecot_config"), + patch("charm.ensure_storage_ready"), + patch("charm.teardown_detaching_storage"), + patch("charm.shutil.which", return_value=None), + ctx(ctx.on.config_changed(), state_in) as mgr, + ): + assert mgr.charm._is_primary is False + + +# --------------------------------------------------------------------------- +# HA: reconcile calls sync script only on primary with known secondary +# --------------------------------------------------------------------------- + + +def test_reconcile_skips_sync_script_when_not_primary(ctx, base_state): + """When this unit is NOT primary, sync script and cronjob are not installed.""" + # Use a valid config but override _is_primary to False to bypass pydantic + # validation (which requires primary-unit to match an existing unit). + with ( + patch("charm.ensure_storage_ready"), + patch("charm.teardown_detaching_storage"), + patch("charm.shutil.which", return_value="/usr/bin/doveconf"), + patch("charm.DovecotCharm._setup_tls"), + patch("charm.DovecotCharm._setup_dovecot"), + patch("charm.DovecotCharm._setup_procmail"), + # ssh keygen — real subprocess not under test + patch("charm.DovecotCharm._setup_ssh_keys"), + # authorized_keys sync — not under test + patch("charm.DovecotCharm._sync_authorized_keys"), + # Override _is_primary to simulate being a non-primary unit + patch("charm.DovecotCharm._is_primary", new_callable=PropertyMock, return_value=False), + # These should NOT be called — we verify via state not mocks + patch("charm.DovecotCharm._install_mail_sync_script") as mock_sync, + patch("charm.DovecotCharm._setup_mail_sync_cronjob") as mock_cron, + ): + state_out = ctx.run(ctx.on.config_changed(), base_state) + + # Charm still reaches Active even without sync scripts + assert isinstance(state_out.unit_status, ops.ActiveStatus) + # Secondary check: these should not have been called since unit is not primary + mock_sync.assert_not_called() + mock_cron.assert_not_called() + + +# --------------------------------------------------------------------------- +# Clear-queue action +# --------------------------------------------------------------------------- def test_clear_queue_deferred(ctx, base_state): + """clear-queue action with queue=deferred passes correct args to postsuper.""" mock_result = MagicMock(stdout="cleared") - with patch("charm.subprocess.run", return_value=mock_result) as mock_run: + with ( + # postsuper is the only subprocess call in this action path + patch("charm.subprocess.run", return_value=mock_result) as mock_run, + ): ctx.run( ctx.on.action("clear-queue", params={"queue": "deferred"}), base_state, @@ -100,8 +206,12 @@ def test_clear_queue_deferred(ctx, base_state): def test_clear_queue_all(ctx, base_state): + """clear-queue action with queue=all omits the deferred queue filter.""" mock_result = MagicMock(stdout="cleared") - with patch("charm.subprocess.run", return_value=mock_result) as mock_run: + with ( + # postsuper is the only subprocess call in this action path + patch("charm.subprocess.run", return_value=mock_result) as mock_run, + ): ctx.run( ctx.on.action("clear-queue", params={"queue": "all"}), base_state, @@ -116,7 +226,9 @@ def test_clear_queue_all(ctx, base_state): def test_clear_queue_failure(ctx, base_state): + """clear-queue action must fail when postsuper returns non-zero.""" with ( + # simulate postsuper failure patch( "charm.subprocess.run", side_effect=CalledProcessError(1, "postsuper", stderr="error msg"), @@ -130,45 +242,18 @@ def test_clear_queue_failure(ctx, base_state): assert "postsuper" in exc_info.value.message -def test_install_calls_all_setup_steps(ctx, base_state): - with ( - patch("charm.apt") as mock_apt, - patch("charm.shutil.copy") as mock_copy, - patch("charm.DovecotCharm._open_ports") as mock_open_ports, - patch("charm.DovecotCharm._setup_dovecot") as mock_dovecot, - patch("charm.DovecotCharm._setup_procmail") as mock_procmail, - patch("charm.DovecotCharm._setup_ssh_keys") as mock_setup_ssh, - patch("charm.DovecotCharm._install_mail_sync_script"), - patch("charm.DovecotCharm._setup_mail_sync_cronjob"), - ): - ctx.run(ctx.on.install(), base_state) - - mock_apt.update.assert_called_once() - mock_apt.add_package.assert_called_once() - mock_copy.assert_called_once_with("/etc/hostname", "/etc/mailname") - mock_open_ports.assert_called_once() - mock_dovecot.assert_called_once() - mock_procmail.assert_called_once() - mock_setup_ssh.assert_called_once() - - -def test_is_primary_true(ctx, base_state): - with patch("charm.DovecotCharm._install"), ctx(ctx.on.config_changed(), base_state) as mgr: - assert mgr.charm._is_primary is True - - -def test_is_primary_false(ctx, base_state): - state_in = dataclasses.replace( - base_state, config={**base_state.config, "primary-unit": "dovecot-charm/999"} - ) - with patch("charm.DovecotCharm._install"), ctx(ctx.on.config_changed(), state_in) as mgr: - assert mgr.charm._is_primary is False +# --------------------------------------------------------------------------- +# Force-sync action +# --------------------------------------------------------------------------- def test_force_sync_success(ctx, base_state): + """force-sync succeeds when this unit is primary and a secondary exists.""" mock_result = MagicMock(stdout="ok", stderr="") with ( + # sync script subprocess call — the action delegates to the shell script patch("charm.subprocess.run", return_value=mock_result), + # provide a secondary hostname so the action doesn't bail out patch.object( DovecotCharm, "_secondary_hostname", @@ -181,26 +266,32 @@ def test_force_sync_success(ctx, base_state): def test_force_sync_not_primary(ctx, base_state): - state_in = dataclasses.replace( - base_state, config={**base_state.config, "primary-unit": "dovecot-charm/999"} - ) - with pytest.raises(ops.testing.ActionFailed) as exc_info: - ctx.run(ctx.on.action("force-sync"), state_in) + """force-sync must fail when executed on a non-primary unit.""" + # Override _is_primary since pydantic rejects unknown unit names + with ( + patch("charm.DovecotCharm._is_primary", new_callable=PropertyMock, return_value=False), + pytest.raises(ops.testing.ActionFailed) as exc_info, + ): + ctx.run(ctx.on.action("force-sync"), base_state) assert "primary unit" in exc_info.value.message def test_force_sync_no_secondary(ctx, base_state): + """force-sync must fail when no secondary unit hostname is available.""" with pytest.raises(ops.testing.ActionFailed) as exc_info: ctx.run(ctx.on.action("force-sync"), base_state) assert "secondary" in exc_info.value.message def test_force_sync_subprocess_failure(ctx, base_state): + """force-sync must fail when the sync script exits non-zero.""" with ( + # sync script fails patch( "charm.subprocess.run", side_effect=CalledProcessError(1, "sync", stderr="fail"), ), + # provide secondary so the action reaches subprocess.run patch.object( DovecotCharm, "_secondary_hostname", diff --git a/dovecot-charm/tests/unit/test_storage.py b/dovecot-charm/tests/unit/test_storage.py index c204655..4306c52 100644 --- a/dovecot-charm/tests/unit/test_storage.py +++ b/dovecot-charm/tests/unit/test_storage.py @@ -29,6 +29,11 @@ def test_start_uses_saved_dev_path_when_model_error(ctx, base_state): patch("charm.DovecotCharm._setup_tls"), patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), + # HA methods do filesystem I/O (ssh-keygen, authorized_keys, sync scripts) + patch("charm.DovecotCharm._setup_ssh_keys"), + patch("charm.DovecotCharm._sync_authorized_keys"), + patch("charm.DovecotCharm._install_mail_sync_script"), + patch("charm.DovecotCharm._setup_mail_sync_cronjob"), patch("ops._main._Dispatcher.run_any_legacy_hook"), ): state_out = ctx.run(ctx.on.start(), state_in) @@ -72,6 +77,11 @@ def test_storage_attached_luks_auto_provisioning_disabled_mounted_is_active(ctx, patch("charm.DovecotCharm._setup_tls"), patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), + # HA methods do filesystem I/O — not under test + patch("charm.DovecotCharm._setup_ssh_keys"), + patch("charm.DovecotCharm._sync_authorized_keys"), + patch("charm.DovecotCharm._install_mail_sync_script"), + patch("charm.DovecotCharm._setup_mail_sync_cronjob"), ): state_out = ctx.run(ctx.on.storage_attached(storage), state_in) assert isinstance(state_out.unit_status, ops.ActiveStatus) @@ -105,6 +115,11 @@ def test_storage_attached_calls_setup_luks_with_key(ctx, base_state): patch("charm.DovecotCharm._setup_tls"), patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), + # HA methods do filesystem I/O — not under test + patch("charm.DovecotCharm._setup_ssh_keys"), + patch("charm.DovecotCharm._sync_authorized_keys"), + patch("charm.DovecotCharm._install_mail_sync_script"), + patch("charm.DovecotCharm._setup_mail_sync_cronjob"), ): state_out = ctx.run(ctx.on.storage_attached(storage), state_in) assert isinstance(state_out.unit_status, ops.ActiveStatus) @@ -125,6 +140,11 @@ def test_storage_attached_saves_dev_path(ctx, base_state): patch("charm.DovecotCharm._setup_tls"), patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), + # HA methods do filesystem I/O — not under test + patch("charm.DovecotCharm._setup_ssh_keys"), + patch("charm.DovecotCharm._sync_authorized_keys"), + patch("charm.DovecotCharm._install_mail_sync_script"), + patch("charm.DovecotCharm._setup_mail_sync_cronjob"), ): state_out = ctx.run(ctx.on.storage_attached(storage), state_in) assert isinstance(state_out.unit_status, ops.ActiveStatus) @@ -184,6 +204,11 @@ def test_storage_detaching_unmount_and_close(ctx, base_state): patch("charm.DovecotCharm._setup_tls"), patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), + # HA methods do filesystem I/O — not under test + patch("charm.DovecotCharm._setup_ssh_keys"), + patch("charm.DovecotCharm._sync_authorized_keys"), + patch("charm.DovecotCharm._install_mail_sync_script"), + patch("charm.DovecotCharm._setup_mail_sync_cronjob"), ): state_out = ctx.run(ctx.on.storage_detaching(storage), state_in) assert isinstance(state_out.unit_status, ops.ActiveStatus) @@ -219,6 +244,11 @@ def test_storage_detaching_luks_disabled_skips_close(ctx, base_state): patch("charm.DovecotCharm._setup_tls"), patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), + # HA methods do filesystem I/O — not under test + patch("charm.DovecotCharm._setup_ssh_keys"), + patch("charm.DovecotCharm._sync_authorized_keys"), + patch("charm.DovecotCharm._install_mail_sync_script"), + patch("charm.DovecotCharm._setup_mail_sync_cronjob"), ): state_out = ctx.run(ctx.on.storage_detaching(storage), state_in) assert isinstance(state_out.unit_status, ops.ActiveStatus) diff --git a/dovecot-charm/tests/unit/test_tls.py b/dovecot-charm/tests/unit/test_tls.py index f654cbc..58b180a 100644 --- a/dovecot-charm/tests/unit/test_tls.py +++ b/dovecot-charm/tests/unit/test_tls.py @@ -57,6 +57,11 @@ def test_setup_tls_writes_cert_key_and_chain(ctx, base_state, tmp_path): # Isolate from dovecot/procmail filesystem writes patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), + # HA methods do filesystem I/O (ssh-keygen, authorized_keys, sync scripts) + patch("charm.DovecotCharm._setup_ssh_keys"), + patch("charm.DovecotCharm._sync_authorized_keys"), + patch("charm.DovecotCharm._install_mail_sync_script"), + patch("charm.DovecotCharm._setup_mail_sync_cronjob"), ctx(ctx.on.config_changed(), base_state) as mgr, ): # Override the TLS library instance so get_assigned_certificate @@ -95,6 +100,11 @@ def test_setup_tls_no_ca_omits_chain(ctx, base_state, tmp_path): patch("charm.shutil.which", return_value="/usr/bin/doveconf"), patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), + # HA methods do filesystem I/O — not under test + patch("charm.DovecotCharm._setup_ssh_keys"), + patch("charm.DovecotCharm._sync_authorized_keys"), + patch("charm.DovecotCharm._install_mail_sync_script"), + patch("charm.DovecotCharm._setup_mail_sync_cronjob"), ctx(ctx.on.config_changed(), base_state) as mgr, ): mgr.charm._tls = MagicMock() @@ -153,6 +163,11 @@ def test_certificate_available_event_triggers_reconcile(ctx, base_state, tmp_pat ), patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), + # HA methods do filesystem I/O — not under test + patch("charm.DovecotCharm._setup_ssh_keys"), + patch("charm.DovecotCharm._sync_authorized_keys"), + patch("charm.DovecotCharm._install_mail_sync_script"), + patch("charm.DovecotCharm._setup_mail_sync_cronjob"), ): # Fire certificate_available via config_changed (same handler) state_out = ctx.run(ctx.on.config_changed(), base_state) From e7411958821af9a5f9e27d34457979a1a3f6040b Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Mon, 20 Apr 2026 11:49:02 +0300 Subject: [PATCH 18/39] chore: fmt --- dovecot-charm/tests/unit/test_tls.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dovecot-charm/tests/unit/test_tls.py b/dovecot-charm/tests/unit/test_tls.py index fdbd869..b31756d 100644 --- a/dovecot-charm/tests/unit/test_tls.py +++ b/dovecot-charm/tests/unit/test_tls.py @@ -3,11 +3,12 @@ """Unit tests for TLS certificate integration.""" from unittest.mock import MagicMock, patch -from exceptions import ConfigurationError import ops import pytest +from exceptions import ConfigurationError + def test_no_tls_cert_yet_blocks(ctx, base_state): """Charm must be Blocked when the TLS relation has no certificate yet. From 2c94239cd5562bee1b30015a1d98e13500c868e4 Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Mon, 20 Apr 2026 12:59:44 +0300 Subject: [PATCH 19/39] fix(tls): correct charm name in TLS relation integration --- .opencode/plans/pr3-tls-review-fixes.md | 19 ------------------- dovecot-charm/tests/integration/conftest.py | 2 +- 2 files changed, 1 insertion(+), 20 deletions(-) delete mode 100644 .opencode/plans/pr3-tls-review-fixes.md diff --git a/.opencode/plans/pr3-tls-review-fixes.md b/.opencode/plans/pr3-tls-review-fixes.md deleted file mode 100644 index e086736..0000000 --- a/.opencode/plans/pr3-tls-review-fixes.md +++ /dev/null @@ -1,19 +0,0 @@ -# Plan: pr/3-tls Review Fixes - -**Branch:** `pr/3-tls` (rebased on `origin/main` after pr/2 merge) -**Problem:** The 9 commits have the original (pre-review) TLS implementation. All review fixes were lost during rebase (never committed). -**Goal:** Re-apply all fixes, commit, push. - -## Steps - -1. `constants.py` — add `TLS_CERT_DIR = Path("/etc/dovecot/private")` -2. `charm.py` — delete `_on_certificate_available`, create `_setup_tls`, wire `certificate_available` → `_reconcile`, use `TLS_CERT_DIR` constant, remove `tls_enabled` from template context -3. `dovecot.conf.tmpl` — always `ssl = required`, remove conditional and stale TODO -4. `test_charm.py` — delete 3 old TLS tests, replace `_install` patches with `_setup_tls`, add comments -5. `test_storage.py` — add `_setup_tls` patches to 6 tests reaching ActiveStatus -6. `tests/unit/test_tls.py` (new) — 6 tests following SKILL.md principles -7. `tests/integration/test_tls.py` — fix copyright, remove sleeps, fix stat quoting, add ssl=required assertion -8. `docs/explanation/charm-state-diagrams.md` — update for TLS states -9. Remove `dovecot-charm-state-diagrams.rst` if exists -10. Run `tox -e fmt,unit,lint` — all must pass -11. Commit and push diff --git a/dovecot-charm/tests/integration/conftest.py b/dovecot-charm/tests/integration/conftest.py index e1b3fc6..c96f64b 100644 --- a/dovecot-charm/tests/integration/conftest.py +++ b/dovecot-charm/tests/integration/conftest.py @@ -118,7 +118,7 @@ def dovecot_charm_manual_storage( try: logging.info("Adding TLS relation...") - juju.integrate(f"{dovecot_charm}:certificates", f"{tls_charm}:certificates") + juju.integrate(f"{charm_name}:certificates", f"{tls_charm}:certificates") except Exception: logging.info("TLS relation already there...") From 44630142eb18e9d533df3b62a27cd4b1403d3cb1 Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Mon, 20 Apr 2026 13:19:29 +0300 Subject: [PATCH 20/39] refactor(tls): replace inline status check with jubilant.all_active for clarity --- dovecot-charm/tests/integration/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dovecot-charm/tests/integration/conftest.py b/dovecot-charm/tests/integration/conftest.py index c96f64b..cbd7c40 100644 --- a/dovecot-charm/tests/integration/conftest.py +++ b/dovecot-charm/tests/integration/conftest.py @@ -82,7 +82,7 @@ def dovecot_charm( logging.info("TLS relation already there...") logging.info("Waiting for active status...") juju.wait( - lambda status: status.apps[APP_NAME].is_active and status.apps[tls_charm].is_active, + lambda status: jubilant.all_active(status, APP_NAME, tls_charm), timeout=10 * 60, ) return APP_NAME From 9e96d053556cdbedf40b08b745bd36c4569f63e5 Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Mon, 20 Apr 2026 13:33:55 +0300 Subject: [PATCH 21/39] fix(tls): close plaintext ports 143 and 110 since TLS is mandatory --- dovecot-charm/src/charm.py | 4 +--- dovecot-charm/tests/integration/test_config.py | 12 ------------ dovecot-charm/tests/unit/test_charm.py | 2 +- 3 files changed, 2 insertions(+), 16 deletions(-) diff --git a/dovecot-charm/src/charm.py b/dovecot-charm/src/charm.py index 0b5fddf..dcde3b9 100644 --- a/dovecot-charm/src/charm.py +++ b/dovecot-charm/src/charm.py @@ -166,10 +166,8 @@ def _install(self): self.unit.status = MaintenanceStatus("Charm installation done") def _open_ports(self): - """Open mail ports.""" - self.unit.open_port("tcp", 143) + """Open mail ports (TLS-only: plaintext 143/110 are not exposed).""" self.unit.open_port("tcp", 993) - self.unit.open_port("tcp", 110) self.unit.open_port("tcp", 995) self.unit.open_port("tcp", 4190) self.unit.open_port("tcp", 9900) diff --git a/dovecot-charm/tests/integration/test_config.py b/dovecot-charm/tests/integration/test_config.py index ad802cb..e76facb 100644 --- a/dovecot-charm/tests/integration/test_config.py +++ b/dovecot-charm/tests/integration/test_config.py @@ -10,24 +10,12 @@ def test_dovecot_protocol_responses(juju: jubilant.Juju, dovecot_charm: str): """Verify Dovecot responds to simple IMAP and POP3 commands.""" unit_name = f"{dovecot_charm}/0" - logging.info("Checking IMAP response on port 143...") - juju.exec( - "curl -fsS --max-time 10 --url imap://127.0.0.1:143 --request CAPABILITY | grep -q 'CAPABILITY'", - unit=unit_name, - ) - logging.info("Checking IMAPS response on port 993...") juju.exec( "curl -fsS --insecure --max-time 10 --url imaps://127.0.0.1:993 --request CAPABILITY | grep -q 'CAPABILITY'", unit=unit_name, ) - logging.info("Checking POP3 response on port 110...") - juju.exec( - "curl -fsS --max-time 10 --url pop3://127.0.0.1:110 --request CAPA | grep -Eq '(\\+OK|CAPA)'", - unit=unit_name, - ) - logging.info("Checking POP3S response on port 995...") juju.exec( "curl -fsS --insecure --max-time 10 --url pop3s://127.0.0.1:995 --request CAPA | grep -Eq '(\\+OK|CAPA)'", diff --git a/dovecot-charm/tests/unit/test_charm.py b/dovecot-charm/tests/unit/test_charm.py index 9691eec..a5f5676 100644 --- a/dovecot-charm/tests/unit/test_charm.py +++ b/dovecot-charm/tests/unit/test_charm.py @@ -22,7 +22,7 @@ def test_open_ports(ctx, base_state): ): state_out = ctx.run(ctx.on.config_changed(), base_state) - expected = {ops.testing.TCPPort(p) for p in [143, 993, 110, 995, 4190, 9900]} + expected = {ops.testing.TCPPort(p) for p in [993, 995, 4190, 9900]} assert state_out.opened_ports == expected From d807952145dec0fcaefd6ccf737af9ba4f43e671 Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Mon, 20 Apr 2026 10:10:58 +0300 Subject: [PATCH 22/39] feat(dovecot-charm): add HA support with SSH key exchange and force-sync action --- docs/release-notes/artifacts/pr-4-ha.yaml | 4 + docs/release-notes/index.rst | 1 + docs/release-notes/release-notes-0005.rst | 57 ++++++++ dovecot-charm/charmcraft.yaml | 2 + dovecot-charm/src/charm.py | 127 ++++++++++++++++++ .../templates/sync-to-secondary.sh.tmpl | 20 +++ .../templates/sync-to-secondary_cron.tmpl | 3 + dovecot-charm/tests/integration/test_ha.py | 89 ++++++++++++ dovecot-charm/tests/unit/test_charm.py | 87 +++++++++++- 9 files changed, 389 insertions(+), 1 deletion(-) create mode 100644 docs/release-notes/artifacts/pr-4-ha.yaml create mode 100644 docs/release-notes/release-notes-0005.rst create mode 100644 dovecot-charm/templates/sync-to-secondary.sh.tmpl create mode 100644 dovecot-charm/templates/sync-to-secondary_cron.tmpl create mode 100644 dovecot-charm/tests/integration/test_ha.py diff --git a/docs/release-notes/artifacts/pr-4-ha.yaml b/docs/release-notes/artifacts/pr-4-ha.yaml new file mode 100644 index 0000000..300a046 --- /dev/null +++ b/docs/release-notes/artifacts/pr-4-ha.yaml @@ -0,0 +1,4 @@ +name: pr-4-ha +type: major +summary: HA support with SSH key exchange and force-sync action +url: https://github.com/canonical/mailserver-operators/pull/4 diff --git a/docs/release-notes/index.rst b/docs/release-notes/index.rst index 23c635d..9ca49e0 100644 --- a/docs/release-notes/index.rst +++ b/docs/release-notes/index.rst @@ -35,3 +35,4 @@ Releases release-notes-0002 release-notes-0003 release-notes-0004 + release-notes-0005 diff --git a/docs/release-notes/release-notes-0005.rst b/docs/release-notes/release-notes-0005.rst new file mode 100644 index 0000000..9a5ac16 --- /dev/null +++ b/docs/release-notes/release-notes-0005.rst @@ -0,0 +1,57 @@ +.. _release_notes_release_notes_0005: + +Dovecot release notes – 2.3/edge +================================= + +These release notes cover new features and changes in Dovecot. + +Main features: + +* Added HA support with SSH key exchange and ``force-sync`` action. + +See our :ref:`Release policy and schedule `. + +Requirements and compatibility +------------------------------- + +The charm operates Dovecot 2.3. + +.. list-table:: + :header-rows: 1 + :widths: 50 50 + + * - Software + - Required version + * - Juju + - 3.x + * - Ubuntu + - 24.04 + +Updates +------- + +The following major and minor features were added in this release. + +HA support with SSH key exchange and force-sync action +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +High-availability support was added to the Dovecot charm. The charm now +exchanges SSH keys between primary and secondary units during installation, +enabling passwordless root SSH access required for mail pool synchronisation. +A new ``force-sync`` action was introduced, allowing operators to trigger an +immediate synchronisation of the mail pool from the primary unit to the +secondary unit on demand. + +Relevant links: + +* `PR `_ + +Bug fixes +--------- + +No bug fixes in this release. + +Known issues +------------ + +No known issues. diff --git a/dovecot-charm/charmcraft.yaml b/dovecot-charm/charmcraft.yaml index 7ef911e..616d42e 100644 --- a/dovecot-charm/charmcraft.yaml +++ b/dovecot-charm/charmcraft.yaml @@ -97,3 +97,5 @@ actions: description: The queue to clear (deferred or all). default: deferred enum: [deferred, all] + force-sync: + description: Manually trigger synchronization of mail to the secondary unit. diff --git a/dovecot-charm/src/charm.py b/dovecot-charm/src/charm.py index dcde3b9..55a8cce 100644 --- a/dovecot-charm/src/charm.py +++ b/dovecot-charm/src/charm.py @@ -5,7 +5,9 @@ """Dovecot charm.""" import logging +import os import shutil +import socket import subprocess # nosec import typing from pathlib import Path @@ -57,6 +59,8 @@ def __init__(self, *args): self.framework.observe(self.on.clear_queue_action, self._on_clear_queue_action) self.framework.observe(self.on.mail_data_storage_attached, self._reconcile) self.framework.observe(self.on.mail_data_storage_detaching, self._reconcile) + self.framework.observe(self.on.replicas_relation_changed, self._on_replicas_changed) + self.framework.observe(self.on.force_sync_action, self._on_force_sync) self.framework.observe( self.on[PEER_RELATION_NAME].relation_created, @@ -67,6 +71,13 @@ def __init__(self, *args): loader=jinja2.FileSystemLoader(TEMPLATES_DIR), autoescape=True ) + # Sync to secondary + self.sync_smtp_aliases_target = "/usr/local/bin/sync-smtp-aliases.sh" + self.sync_to_secondary_target = "/usr/local/bin/sync-to-secondary.sh" + self.sync_to_secondary_cronjob_target = "/etc/cron.d/sync-to-secondary" + self.sync_to_secondary_template = "sync-to-secondary.sh.tmpl" + self.sync_to_secondary_cronjob_template = "sync-to-secondary_cron.tmpl" + # TLS certificates integration self._tls = None mailname = self.config.get("mailname", "") @@ -107,6 +118,11 @@ def _on_peer_relation_created(self, event): relation_data = event.relation.data[self.unit] relation_data["unit-name"] = self.unit.name + @property + def _is_primary(self): + """Return True if this unit is the configured primary unit.""" + return self.unit.name == self.config.get("primary-unit", "") + def _get_dovecot_config(self) -> DovecotConfig: """Craft the DovecotConfig from charm configuration and validate it. @@ -163,6 +179,10 @@ def _install(self): apt.update() apt.add_package(REQUIRED_PACKAGES) shutil.copy(HOSTNAME_FILE, MAILNAME_FILE) + self._setup_ssh_keys() + if self._is_primary: + self._install_mail_sync_script() + self._setup_mail_sync_cronjob() self.unit.status = MaintenanceStatus("Charm installation done") def _open_ports(self): @@ -269,6 +289,113 @@ def _on_clear_queue_action(self, event): logger.exception(f"Failed to clear Postfix queue: {e.stderr}") event.fail(f"Failed to run postsuper: {e.stderr}") + @property + def _secondary_hostname(self): + """Return the hostname/IP of the secondary unit.""" + relation = self.model.get_relation("replicas") + if not relation: + return None + + for unit in relation.units: + return ( + relation.data[unit].get("hostname") + or relation.data[unit].get("private-address") + or relation.data[unit].get("ingress-address") + ) + + return None + + def _setup_ssh_keys(self): + """Generate SSH key and share public key via peer relation.""" + ssh_dir = Path("/root/.ssh") + ssh_dir.mkdir(mode=0o700, exist_ok=True) + key_file = ssh_dir / "id_ed25519" + + if not key_file.exists(): + logger.warning("keyfile not there") + os.system(f'ssh-keygen -t ed25519 -N "" -f {key_file}') # noqa: S605 + + pub_key = (ssh_dir / "id_ed25519.pub").read_text().strip() + relation = self.model.get_relation("replicas") + if relation: + relation.data[self.unit]["public_key"] = pub_key + relation.data[self.unit]["hostname"] = socket.gethostname() + + config_file = ssh_dir / "config" + if not config_file.exists(): + config_file.write_text("Host *\n StrictHostKeyChecking no\n") + config_file.chmod(0o600) + + def _on_replicas_changed(self, event): + """Handle replicas relation changed — sync SSH authorized_keys.""" + authorized_keys = [] + relation = self.model.get_relation("replicas") + + for unit in relation.units: + pk = relation.data[unit].get("public_key") + if pk: + authorized_keys.append(pk) + + our_pk = relation.data[self.unit].get("public_key") + if our_pk: + authorized_keys.append(our_pk) + + auth_file = Path("/root/.ssh/authorized_keys") + auth_file.write_text("\n".join(authorized_keys)) + auth_file.chmod(0o600) + + self._ensure_root_ssh_configs() + + def _ensure_root_ssh_configs(self): + """Ensure PermitRootLogin is set in sshd_config.""" + cmd = "sed -i 's/^#*PermitRootLogin.*/PermitRootLogin prohibit-password/' /etc/ssh/sshd_config" + os.system(cmd) # noqa: S605 + os.system("systemctl restart ssh") # noqa: S605, S607 + + def _install_mail_sync_script(self): + """Install mail pool synchronization script.""" + self.unit.status = MaintenanceStatus("Installing mail pool synchronization script") + template_context = { + "secondary_hostname": self._secondary_hostname, + "mail_root": MAIL_ROOT, + } + template = self.jinja.get_template(self.sync_to_secondary_template) + contents = template.render(template_context) + host.write_file(self.sync_to_secondary_target, contents, perms=0o755) + self.unit.status = MaintenanceStatus("Mail pool synchronization installed") + + def _setup_mail_sync_cronjob(self): + """Set up mail pool synchronization cronjob.""" + self.unit.status = MaintenanceStatus("Setting up mail pool synchronization cronjob") + template_context = { + "schedule": self.config.get("sync-schedule", "*/30 * * * *"), + } + template = self.jinja.get_template(self.sync_to_secondary_cronjob_template) + contents = template.render(template_context) + host.write_file(self.sync_to_secondary_cronjob_target, contents, perms=0o644) + systemd.service_restart("cron") + self.unit.status = MaintenanceStatus("Mail pool synchronization cronjob has been set up") + + def _on_force_sync(self, event): + """Force synchronization with secondary unit.""" + if not self._is_primary: + event.fail("This action can only be run on the primary unit.") + return + + if not self._secondary_hostname: + event.fail("No secondary unit found to sync to.") + return + + try: + cmd = [self.sync_to_secondary_target] + logger.info(f"Running manual sync: {' '.join(cmd)}") + subprocess.run(cmd, check=True, capture_output=True, text=True) + event.set_results({"result": "Sync completed successfully"}) + except subprocess.CalledProcessError as e: + msg = f"Sync failed: {e.stderr}" + logger.error(msg) + event.fail(msg) + def _setup_tls(self, dovecot_config: DovecotConfig) -> None: """Write TLS cert+key to disk from the certificates relation. diff --git a/dovecot-charm/templates/sync-to-secondary.sh.tmpl b/dovecot-charm/templates/sync-to-secondary.sh.tmpl new file mode 100644 index 0000000..126ad02 --- /dev/null +++ b/dovecot-charm/templates/sync-to-secondary.sh.tmpl @@ -0,0 +1,20 @@ +#!/bin/bash + +set -eu + +# Sync using doveadm (dsync) for users that have a Maildir. +# Avoids syncing system accounts without mailboxes. +remote="remote:root@{{ secondary_hostname }}" +found=0 +for user_dir in "{{ mail_root }}"/*; do + if [ -d "$user_dir/Maildir" ]; then + user="$(basename "$user_dir")" + doveadm backup -u "$user" "$remote" + found=1 + fi +done +if [ "$found" -eq 0 ]; then + echo "No Maildir found under {{ mail_root }}; nothing to sync." >&2 + exit 1 +fi +touch {{ mail_root }}/.last-dsync diff --git a/dovecot-charm/templates/sync-to-secondary_cron.tmpl b/dovecot-charm/templates/sync-to-secondary_cron.tmpl new file mode 100644 index 0000000..d3101cb --- /dev/null +++ b/dovecot-charm/templates/sync-to-secondary_cron.tmpl @@ -0,0 +1,3 @@ +{{ schedule }} root /usr/local/bin/sync-to-secondary.sh >> /var/log/sync-to-secondary.log 2>&1 + +# End of file diff --git a/dovecot-charm/tests/integration/test_ha.py b/dovecot-charm/tests/integration/test_ha.py new file mode 100644 index 0000000..c20dccb --- /dev/null +++ b/dovecot-charm/tests/integration/test_ha.py @@ -0,0 +1,89 @@ +# Copyright 2024 Canonical Ltd. +# See LICENSE file for licensing details. + +import logging +from typing import cast + +import jubilant +import pytest + + +def _get_unit_hostname(status, app_name, unit_name): + """Helper to get unit hostname from status.""" + try: + machine = status.apps[app_name].units[unit_name].machine + return status.machines[machine].hostname + except KeyError: + logging.error(f"Unit {unit_name} not found in status.") + return None + + +@pytest.mark.timeout(1800) +def test_ha_failover(juju, dovecot_charm): + status = juju.status() + if len(status.apps[dovecot_charm].units) < 2: + logging.info("Adding the second unit...") + juju.add_unit(dovecot_charm, num_units=1) + + def two_units_active(status): + app = status.apps.get(dovecot_charm) + if not app: + return False + if len(app.units) < 2: + return False + return jubilant.all_active(status) + + logging.info("Waiting for 2 units to be active...") + juju.wait(two_units_active, timeout=600) + + status = juju.status() + units = list(status.apps[dovecot_charm].units.keys()) + units.sort(key=lambda x: int(x.split("/")[-1])) + + primary = units[0] + secondary = units[1] + + logging.info(f"Primary: {primary}, Secondary: {secondary}") + + juju.config(dovecot_charm, {"primary-unit": primary}) + juju.wait(jubilant.all_active, timeout=300) + + logging.info("Verifying SSH key exchange...") + + cmd = "cat /root/.ssh/authorized_keys | wc -l" + + result_primary = juju.exec(cmd, unit=primary) + logging.info(f"Primary authorized_keys count: {result_primary.stdout.strip()}") + assert int(result_primary.stdout.strip()) >= 1 + + result_secondary = juju.exec(cmd, unit=secondary) + logging.info(f"Secondary authorized_keys count: {result_secondary.stdout.strip()}") + assert int(result_secondary.stdout.strip()) >= 1 + + logging.info("Verifying sync script on Primary...") + + status = juju.status() + secondary_hostname = _get_unit_hostname(status, dovecot_charm, secondary) + logging.info(f"Secondary hostname: {secondary_hostname}") + + script_path = "/usr/local/bin/sync-to-secondary.sh" + cmd = f"cat {script_path}" + script_content = juju.exec(cmd, unit=primary).stdout + + logging.info(f"Sync script content on Primary:\n{script_content}") + assert secondary_hostname in script_content, ( + "Secondary hostname not found in sync script on Primary" + ) + + logging.info("Running force-sync on Primary...") + + task = juju.run(unit=primary, action="force-sync", wait=100) + assert task.status == "completed" + assert task.results["result"] == "Sync completed successfully" + + with pytest.raises(jubilant.TaskError) as exc_info: + juju.run(unit=secondary, action="force-sync", wait=100) + assert cast(jubilant.TaskError, exc_info.value).task.status == "failed" + logging.info("force-sync on Secondary correctly failed.") + + logging.info("HA Failover test passed.") diff --git a/dovecot-charm/tests/unit/test_charm.py b/dovecot-charm/tests/unit/test_charm.py index a5f5676..9693a42 100644 --- a/dovecot-charm/tests/unit/test_charm.py +++ b/dovecot-charm/tests/unit/test_charm.py @@ -1,12 +1,14 @@ # Copyright 2026 Canonical Ltd. # See LICENSE file for licensing details. +import dataclasses from subprocess import CalledProcessError # nosec -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, PropertyMock, patch import ops import ops.testing import pytest +from charm import DovecotCharm from exceptions import ConfigurationError @@ -126,3 +128,86 @@ def test_clear_queue_failure(ctx, base_state): base_state, ) assert "postsuper" in exc_info.value.message + + +def test_install_calls_all_setup_steps(ctx, base_state): + with ( + patch("charm.apt") as mock_apt, + patch("charm.shutil.copy") as mock_copy, + patch("charm.DovecotCharm._open_ports") as mock_open_ports, + patch("charm.DovecotCharm._setup_dovecot") as mock_dovecot, + patch("charm.DovecotCharm._setup_procmail") as mock_procmail, + patch("charm.DovecotCharm._setup_ssh_keys") as mock_setup_ssh, + patch("charm.DovecotCharm._install_mail_sync_script"), + patch("charm.DovecotCharm._setup_mail_sync_cronjob"), + ): + ctx.run(ctx.on.install(), base_state) + + mock_apt.update.assert_called_once() + mock_apt.add_package.assert_called_once() + mock_copy.assert_called_once_with("/etc/hostname", "/etc/mailname") + mock_open_ports.assert_called_once() + mock_dovecot.assert_called_once() + mock_procmail.assert_called_once() + mock_setup_ssh.assert_called_once() + + +def test_is_primary_true(ctx, base_state): + with patch("charm.DovecotCharm._install"), ctx(ctx.on.config_changed(), base_state) as mgr: + assert mgr.charm._is_primary is True + + +def test_is_primary_false(ctx, base_state): + state_in = dataclasses.replace( + base_state, config={**base_state.config, "primary-unit": "dovecot-charm/999"} + ) + with patch("charm.DovecotCharm._install"), ctx(ctx.on.config_changed(), state_in) as mgr: + assert mgr.charm._is_primary is False + + +def test_force_sync_success(ctx, base_state): + mock_result = MagicMock(stdout="ok", stderr="") + with ( + patch("charm.subprocess.run", return_value=mock_result), + patch.object( + DovecotCharm, + "_secondary_hostname", + new_callable=PropertyMock, + return_value="10.0.0.2", + ), + ): + ctx.run(ctx.on.action("force-sync"), base_state) + assert ctx.action_results == {"result": "Sync completed successfully"} + + +def test_force_sync_not_primary(ctx, base_state): + state_in = dataclasses.replace( + base_state, config={**base_state.config, "primary-unit": "dovecot-charm/999"} + ) + with pytest.raises(ops.testing.ActionFailed) as exc_info: + ctx.run(ctx.on.action("force-sync"), state_in) + assert "primary unit" in exc_info.value.message + + +def test_force_sync_no_secondary(ctx, base_state): + with pytest.raises(ops.testing.ActionFailed) as exc_info: + ctx.run(ctx.on.action("force-sync"), base_state) + assert "secondary" in exc_info.value.message + + +def test_force_sync_subprocess_failure(ctx, base_state): + with ( + patch( + "charm.subprocess.run", + side_effect=CalledProcessError(1, "sync", stderr="fail"), + ), + patch.object( + DovecotCharm, + "_secondary_hostname", + new_callable=PropertyMock, + return_value="10.0.0.2", + ), + pytest.raises(ops.testing.ActionFailed) as exc_info, + ): + ctx.run(ctx.on.action("force-sync"), base_state) + assert "fail" in exc_info.value.message From e1c91ee4e350490530b11cb1363ff78ecd14a749 Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Mon, 20 Apr 2026 10:10:59 +0300 Subject: [PATCH 23/39] docs: add release notes for pr/4-ha From d3c292f320a46b23ae8a79a0686d9aba94a9b67b Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Mon, 20 Apr 2026 11:33:07 +0300 Subject: [PATCH 24/39] refactor(ha): holistic reconcile, fix security and test issues MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Move all HA setup (SSH keys, authorized_keys, sync script, cronjob) into _reconcile so they're re-evaluated on every event, not just install - Replace os.system calls with subprocess.run and systemd.service_reload - Replace sed-based sshd_config mutation with pure Python - Guard _sync_authorized_keys against missing peer relation - Skip sync script/cronjob when secondary hostname is not yet known - Remove dead code (sync_smtp_aliases_target) - Remove _on_replicas_changed handler — folded into _reconcile - Fix _setup_ssh_keys to handle keygen failure gracefully - Rewrite unit tests per SKILL.md principles: assert on observable state (unit_status, opened_ports), comment every patch - Add HA patches to storage and TLS tests broken by holistic reconcile - Fix integration test copyright year and force-sync empty-maildir bug --- dovecot-charm/src/charm.py | 299 ++++++++++++--------- dovecot-charm/tests/integration/test_ha.py | 6 +- dovecot-charm/tests/unit/test_charm.py | 205 ++++++++++---- dovecot-charm/tests/unit/test_storage.py | 30 +++ dovecot-charm/tests/unit/test_tls.py | 15 ++ 5 files changed, 374 insertions(+), 181 deletions(-) diff --git a/dovecot-charm/src/charm.py b/dovecot-charm/src/charm.py index 55a8cce..bc719fe 100644 --- a/dovecot-charm/src/charm.py +++ b/dovecot-charm/src/charm.py @@ -5,7 +5,6 @@ """Dovecot charm.""" import logging -import os import shutil import socket import subprocess # nosec @@ -44,6 +43,15 @@ logger = logging.getLogger(__name__) +# HA sync paths +SYNC_TO_SECONDARY_TARGET = "/usr/local/bin/sync-to-secondary.sh" +SYNC_TO_SECONDARY_CRONJOB_TARGET = "/etc/cron.d/sync-to-secondary" +SYNC_TO_SECONDARY_TEMPLATE = "sync-to-secondary.sh.tmpl" +SYNC_TO_SECONDARY_CRONJOB_TEMPLATE = "sync-to-secondary_cron.tmpl" + +SSHD_CONFIG = Path("/etc/ssh/sshd_config") +SSH_DIR = Path("/root/.ssh") + class DovecotCharm(CharmBase): """Dovecot IMAP/POP3 mail server charm.""" @@ -51,7 +59,7 @@ class DovecotCharm(CharmBase): def __init__(self, *args): super().__init__(*args) - # Events + # Events — every event except install goes through _reconcile. self.framework.observe(self.on.install, self._on_install) self.framework.observe(self.on.start, self._reconcile) self.framework.observe(self.on.config_changed, self._reconcile) @@ -59,7 +67,7 @@ def __init__(self, *args): self.framework.observe(self.on.clear_queue_action, self._on_clear_queue_action) self.framework.observe(self.on.mail_data_storage_attached, self._reconcile) self.framework.observe(self.on.mail_data_storage_detaching, self._reconcile) - self.framework.observe(self.on.replicas_relation_changed, self._on_replicas_changed) + self.framework.observe(self.on.replicas_relation_changed, self._reconcile) self.framework.observe(self.on.force_sync_action, self._on_force_sync) self.framework.observe( @@ -71,13 +79,6 @@ def __init__(self, *args): loader=jinja2.FileSystemLoader(TEMPLATES_DIR), autoescape=True ) - # Sync to secondary - self.sync_smtp_aliases_target = "/usr/local/bin/sync-smtp-aliases.sh" - self.sync_to_secondary_target = "/usr/local/bin/sync-to-secondary.sh" - self.sync_to_secondary_cronjob_target = "/etc/cron.d/sync-to-secondary" - self.sync_to_secondary_template = "sync-to-secondary.sh.tmpl" - self.sync_to_secondary_cronjob_template = "sync-to-secondary_cron.tmpl" - # TLS certificates integration self._tls = None mailname = self.config.get("mailname", "") @@ -123,6 +124,24 @@ def _is_primary(self): """Return True if this unit is the configured primary unit.""" return self.unit.name == self.config.get("primary-unit", "") + @property + def _secondary_hostname(self) -> typing.Optional[str]: + """Return the hostname/IP of the first remote peer unit, or None.""" + relation = self.model.get_relation(PEER_RELATION_NAME) + if not relation: + return None + + for unit in relation.units: + hostname = ( + relation.data[unit].get("hostname") + or relation.data[unit].get("private-address") + or relation.data[unit].get("ingress-address") + ) + if hostname: + return hostname + + return None + def _get_dovecot_config(self) -> DovecotConfig: """Craft the DovecotConfig from charm configuration and validate it. @@ -144,14 +163,19 @@ def _get_dovecot_config(self) -> DovecotConfig: logger.exception(f"Secret retrieval error: {exc}") raise ConfigurationError(str(exc)) from exc + # -- Event handlers ------------------------------------------------------- + def _on_install(self, event): - """Handle install event.""" + """Handle install event — install packages only, then reconcile.""" self.unit.status = MaintenanceStatus("Installing packages") self._install() self._reconcile(event) def _reconcile(self, event): - """Reconcile charm state for install, upgrade, config-changed, and storage events.""" + """Reconcile charm state for every event except install. + + Holistic handler: storage → TLS → dovecot → procmail → HA → ports. + """ self.unit.status = MaintenanceStatus("Configuring charm") try: dovecot_config = self._get_dovecot_config() @@ -170,21 +194,27 @@ def _reconcile(self, event): except ConfigurationError as e: self.unit.status = BlockedStatus(str(e)) return + # HA: SSH keys, authorized_keys, sync script + cronjob + self._setup_ssh_keys() + self._sync_authorized_keys() + if self._is_primary: + self._install_mail_sync_script() + self._setup_mail_sync_cronjob() self._open_ports() self.unit.status = ops.ActiveStatus() + # -- Installation --------------------------------------------------------- + def _install(self): - """Perform basic installation.""" + """Perform basic installation — packages and hostname only.""" self.unit.status = MaintenanceStatus("Installing required dependencies") apt.update() apt.add_package(REQUIRED_PACKAGES) shutil.copy(HOSTNAME_FILE, MAILNAME_FILE) - self._setup_ssh_keys() - if self._is_primary: - self._install_mail_sync_script() - self._setup_mail_sync_cronjob() self.unit.status = MaintenanceStatus("Charm installation done") + # -- Service configuration ------------------------------------------------ + def _open_ports(self): """Open mail ports (TLS-only: plaintext 143/110 are not exposed).""" self.unit.open_port("tcp", 993) @@ -266,71 +296,84 @@ def _setup_procmail(self) -> None: logger.exception(f"Failed to configure postfix: {e}") raise ConfigurationError(f"Failed to configure postfix: {e.stderr}") from e - def _on_clear_queue_action(self, event): - """Handle the clear-queue action.""" - queue_to_clear = event.params.get("queue", "deferred") + def _setup_tls(self, dovecot_config: DovecotConfig) -> None: + """Write TLS cert+key to disk from the certificates relation. - if queue_to_clear not in ("deferred", "all"): - event.fail("Invalid queue parameter, must be 'deferred' or 'all'") - return - command = ["postsuper", "-d", "ALL"] + Called from _reconcile before _setup_dovecot so the cert files are + present when dovecot.conf is rendered and validated. - if queue_to_clear == "all": - logger.warning("Running clear-queue action: DELETING ALL mail from Postfix queue.") - else: - command.append("deferred") - logger.info("Running clear-queue action: Deleting deferred mail from Postfix queue.") + Raises: + ConfigurationError: If no TLS relation exists or the certificate + has not been issued yet. + """ + if not self._tls: + raise ConfigurationError( + "TLS certificates relation not available. " + "Integrate with a TLS provider using the 'certificates' relation." + ) - try: - # The command and arguments are fixed literals with no user-controlled input. - result = subprocess.run(command, check=True, capture_output=True, text=True) - event.set_results({"status": "success", "output": result.stdout}) - except subprocess.CalledProcessError as e: - logger.exception(f"Failed to clear Postfix queue: {e.stderr}") - event.fail(f"Failed to run postsuper: {e.stderr}") + cert_request = CertificateRequestAttributes( + common_name=dovecot_config.mailname, + sans_dns=frozenset([dovecot_config.mailname]), + ) + provider_cert, private_key = self._tls.get_assigned_certificate(cert_request) + if not provider_cert or not private_key: + raise ConfigurationError( + "TLS certificate not yet available from the certificates relation." + ) - @property - def _secondary_hostname(self): - """Return the hostname/IP of the secondary unit.""" - relation = self.model.get_relation("replicas") - if not relation: - return None + TLS_CERT_DIR.mkdir(parents=True, exist_ok=True) + cert_path = TLS_CERT_DIR / f"{dovecot_config.mailname}.pem" + key_path = TLS_CERT_DIR / f"{dovecot_config.mailname}.key" - for unit in relation.units: - return ( - relation.data[unit].get("hostname") - or relation.data[unit].get("private-address") - or relation.data[unit].get("ingress-address") - ) + cert_content = str(provider_cert.certificate) + if provider_cert.ca: + cert_content += "\n" + str(provider_cert.ca) + cert_path.write_text(cert_content) + cert_path.chmod(0o644) + logger.info(f"TLS certificate written to {cert_path}") - return None + key_path.write_text(str(private_key)) + key_path.chmod(0o600) + logger.info(f"TLS private key written to {key_path}") + + # -- HA / SSH key exchange ------------------------------------------------ def _setup_ssh_keys(self): - """Generate SSH key and share public key via peer relation.""" - ssh_dir = Path("/root/.ssh") - ssh_dir.mkdir(mode=0o700, exist_ok=True) - key_file = ssh_dir / "id_ed25519" + """Generate an SSH key pair if absent and publish the public key via the peer relation.""" + SSH_DIR.mkdir(mode=0o700, exist_ok=True) + key_file = SSH_DIR / "id_ed25519" if not key_file.exists(): - logger.warning("keyfile not there") - os.system(f'ssh-keygen -t ed25519 -N "" -f {key_file}') # noqa: S605 + subprocess.run( # noqa: S603 + ["ssh-keygen", "-t", "ed25519", "-N", "", "-f", str(key_file)], + check=True, + capture_output=True, + ) - pub_key = (ssh_dir / "id_ed25519.pub").read_text().strip() - relation = self.model.get_relation("replicas") + pub_key_file = SSH_DIR / "id_ed25519.pub" + if not pub_key_file.exists(): + logger.error("SSH public key file not found after key generation") + return + + pub_key = pub_key_file.read_text().strip() + relation = self.model.get_relation(PEER_RELATION_NAME) if relation: relation.data[self.unit]["public_key"] = pub_key relation.data[self.unit]["hostname"] = socket.gethostname() - config_file = ssh_dir / "config" + config_file = SSH_DIR / "config" if not config_file.exists(): config_file.write_text("Host *\n StrictHostKeyChecking no\n") config_file.chmod(0o600) - def _on_replicas_changed(self, event): - """Handle replicas relation changed — sync SSH authorized_keys.""" - authorized_keys = [] - relation = self.model.get_relation("replicas") + def _sync_authorized_keys(self): + """Collect public keys from all peer units and write authorized_keys.""" + relation = self.model.get_relation(PEER_RELATION_NAME) + if not relation: + return + authorized_keys = [] for unit in relation.units: pk = relation.data[unit].get("public_key") if pk: @@ -340,41 +383,92 @@ def _on_replicas_changed(self, event): if our_pk: authorized_keys.append(our_pk) - auth_file = Path("/root/.ssh/authorized_keys") - auth_file.write_text("\n".join(authorized_keys)) - auth_file.chmod(0o600) + if not authorized_keys: + return - self._ensure_root_ssh_configs() + auth_file = SSH_DIR / "authorized_keys" + auth_file.write_text("\n".join(authorized_keys) + "\n") + auth_file.chmod(0o600) - def _ensure_root_ssh_configs(self): - """Ensure PermitRootLogin is set in sshd_config.""" - cmd = "sed -i 's/^#*PermitRootLogin.*/PermitRootLogin prohibit-password/' /etc/ssh/sshd_config" - os.system(cmd) # noqa: S605 - os.system("systemctl restart ssh") # noqa: S605, S607 + self._ensure_root_ssh_login() + + def _ensure_root_ssh_login(self): + """Set PermitRootLogin to prohibit-password in sshd_config and reload sshd.""" + if SSHD_CONFIG.exists(): + content = SSHD_CONFIG.read_text() + new_content = "" + found = False + for line in content.splitlines(keepends=True): + stripped = line.lstrip("#").strip() + if stripped.startswith("PermitRootLogin"): + new_content += "PermitRootLogin prohibit-password\n" + found = True + else: + new_content += line + if not found: + new_content += "\nPermitRootLogin prohibit-password\n" + if new_content != content: + SSHD_CONFIG.write_text(new_content) + systemd.service_reload("ssh", restart_on_failure=True) def _install_mail_sync_script(self): - """Install mail pool synchronization script.""" + """Render and install the mail pool synchronization script. + + Skipped when the secondary hostname is not yet known (no remote peer). + """ + secondary = self._secondary_hostname + if not secondary: + logger.info("Secondary hostname not yet known; skipping sync script installation") + return + self.unit.status = MaintenanceStatus("Installing mail pool synchronization script") template_context = { - "secondary_hostname": self._secondary_hostname, + "secondary_hostname": secondary, "mail_root": MAIL_ROOT, } - template = self.jinja.get_template(self.sync_to_secondary_template) + template = self.jinja.get_template(SYNC_TO_SECONDARY_TEMPLATE) contents = template.render(template_context) - host.write_file(self.sync_to_secondary_target, contents, perms=0o755) - self.unit.status = MaintenanceStatus("Mail pool synchronization installed") + host.write_file(SYNC_TO_SECONDARY_TARGET, contents, perms=0o755) def _setup_mail_sync_cronjob(self): - """Set up mail pool synchronization cronjob.""" + """Set up the mail pool synchronization cronjob.""" + if not self._secondary_hostname: + logger.info("Secondary hostname not yet known; skipping cronjob setup") + return + self.unit.status = MaintenanceStatus("Setting up mail pool synchronization cronjob") template_context = { "schedule": self.config.get("sync-schedule", "*/30 * * * *"), } - template = self.jinja.get_template(self.sync_to_secondary_cronjob_template) + template = self.jinja.get_template(SYNC_TO_SECONDARY_CRONJOB_TEMPLATE) contents = template.render(template_context) - host.write_file(self.sync_to_secondary_cronjob_target, contents, perms=0o644) + host.write_file(SYNC_TO_SECONDARY_CRONJOB_TARGET, contents, perms=0o644) systemd.service_restart("cron") - self.unit.status = MaintenanceStatus("Mail pool synchronization cronjob has been set up") + + # -- Actions -------------------------------------------------------------- + + def _on_clear_queue_action(self, event): + """Handle the clear-queue action.""" + queue_to_clear = event.params.get("queue", "deferred") + + if queue_to_clear not in ("deferred", "all"): + event.fail("Invalid queue parameter, must be 'deferred' or 'all'") + return + command = ["postsuper", "-d", "ALL"] + + if queue_to_clear == "all": + logger.warning("Running clear-queue action: DELETING ALL mail from Postfix queue.") + else: + command.append("deferred") + logger.info("Running clear-queue action: Deleting deferred mail from Postfix queue.") + + try: + # The command and arguments are fixed literals with no user-controlled input. + result = subprocess.run(command, check=True, capture_output=True, text=True) + event.set_results({"status": "success", "output": result.stdout}) + except subprocess.CalledProcessError as e: + logger.exception(f"Failed to clear Postfix queue: {e.stderr}") + event.fail(f"Failed to run postsuper: {e.stderr}") def _on_force_sync(self, event): """Force synchronization with secondary unit.""" @@ -387,56 +481,15 @@ def _on_force_sync(self, event): return try: - cmd = [self.sync_to_secondary_target] + cmd = [SYNC_TO_SECONDARY_TARGET] logger.info(f"Running manual sync: {' '.join(cmd)}") - subprocess.run(cmd, check=True, capture_output=True, text=True) + subprocess.run(cmd, check=True, capture_output=True, text=True) # noqa: S603 event.set_results({"result": "Sync completed successfully"}) except subprocess.CalledProcessError as e: msg = f"Sync failed: {e.stderr}" logger.error(msg) event.fail(msg) - def _setup_tls(self, dovecot_config: DovecotConfig) -> None: - """Write TLS cert+key to disk from the certificates relation. - - Called from _reconcile before _setup_dovecot so the cert files are - present when dovecot.conf is rendered and validated. - - Raises: - ConfigurationError: If no TLS relation exists or the certificate - has not been issued yet. - """ - if not self._tls: - raise ConfigurationError( - "TLS certificates relation not available. " - "Integrate with a TLS provider using the 'certificates' relation." - ) - - cert_request = CertificateRequestAttributes( - common_name=dovecot_config.mailname, - sans_dns=frozenset([dovecot_config.mailname]), - ) - provider_cert, private_key = self._tls.get_assigned_certificate(cert_request) - if not provider_cert or not private_key: - raise ConfigurationError( - "TLS certificate not yet available from the certificates relation." - ) - - TLS_CERT_DIR.mkdir(parents=True, exist_ok=True) - cert_path = TLS_CERT_DIR / f"{dovecot_config.mailname}.pem" - key_path = TLS_CERT_DIR / f"{dovecot_config.mailname}.key" - - cert_content = str(provider_cert.certificate) - if provider_cert.ca: - cert_content += "\n" + str(provider_cert.ca) - cert_path.write_text(cert_content) - cert_path.chmod(0o644) - logger.info(f"TLS certificate written to {cert_path}") - - key_path.write_text(str(private_key)) - key_path.chmod(0o600) - logger.info(f"TLS private key written to {key_path}") - if __name__ == "__main__": # pragma: nocover main(DovecotCharm) diff --git a/dovecot-charm/tests/integration/test_ha.py b/dovecot-charm/tests/integration/test_ha.py index c20dccb..4fccd38 100644 --- a/dovecot-charm/tests/integration/test_ha.py +++ b/dovecot-charm/tests/integration/test_ha.py @@ -1,4 +1,4 @@ -# Copyright 2024 Canonical Ltd. +# Copyright 2026 Canonical Ltd. # See LICENSE file for licensing details. import logging @@ -77,6 +77,10 @@ def two_units_active(status): logging.info("Running force-sync on Primary...") + # Create a test Maildir so the sync script has something to sync. + # Without this, the script exits 1 because no Maildir directories exist. + juju.exec("mkdir -p /srv/mail/testuser/Maildir/{new,cur,tmp}", unit=primary) + task = juju.run(unit=primary, action="force-sync", wait=100) assert task.status == "completed" assert task.results["result"] == "Sync completed successfully" diff --git a/dovecot-charm/tests/unit/test_charm.py b/dovecot-charm/tests/unit/test_charm.py index 9693a42..4942018 100644 --- a/dovecot-charm/tests/unit/test_charm.py +++ b/dovecot-charm/tests/unit/test_charm.py @@ -1,5 +1,6 @@ # Copyright 2026 Canonical Ltd. # See LICENSE file for licensing details. +import contextlib import dataclasses from subprocess import CalledProcessError # nosec from unittest.mock import MagicMock, PropertyMock, patch @@ -12,42 +13,71 @@ from exceptions import ConfigurationError -def test_open_ports(ctx, base_state): +# --------------------------------------------------------------------------- +# Helpers — patches shared across many tests +# --------------------------------------------------------------------------- + +@contextlib.contextmanager +def reconcile_guards(): + """Guard all I/O in _reconcile so tests only exercise event wiring / status. + + Use when the test drives an event that triggers _reconcile but the test + is NOT about the logic inside these helpers (storage, TLS, dovecot, etc.). + """ with ( - # Guard real storage/TLS/dovecot operations so only port logic is exercised + # storage module talks to cryptsetup / mount — not under test patch("charm.ensure_storage_ready"), patch("charm.teardown_detaching_storage"), + # doveconf binary check — pretend it's installed patch("charm.shutil.which", return_value="/usr/bin/doveconf"), + # TLS writes cert/key files to disk — not under test patch("charm.DovecotCharm._setup_tls"), + # dovecot config rendering + validation + reload — not under test patch("charm.DovecotCharm._setup_dovecot"), + # procmail config rendering + postfix postconf — not under test patch("charm.DovecotCharm._setup_procmail"), + # SSH keygen + filesystem writes — not under test + patch("charm.DovecotCharm._setup_ssh_keys"), + # authorized_keys sync — not under test + patch("charm.DovecotCharm._sync_authorized_keys"), + # sync script rendering — not under test + patch("charm.DovecotCharm._install_mail_sync_script"), + # cronjob rendering + cron restart — not under test + patch("charm.DovecotCharm._setup_mail_sync_cronjob"), ): - state_out = ctx.run(ctx.on.config_changed(), base_state) + yield - expected = {ops.testing.TCPPort(p) for p in [993, 995, 4190, 9900]} - assert state_out.opened_ports == expected +# --------------------------------------------------------------------------- +# Reconcile: status + ports +# --------------------------------------------------------------------------- -def test_configure_sets_active_on_success(ctx, base_state): - with ( - patch("charm.ensure_storage_ready"), - patch("charm.teardown_detaching_storage"), - patch("charm.shutil.which", return_value="/usr/bin/doveconf"), - patch("charm.DovecotCharm._setup_tls"), - patch("charm.DovecotCharm._setup_dovecot"), - patch("charm.DovecotCharm._setup_procmail"), - ): + +def test_reconcile_sets_active_on_success(ctx, base_state): + """Reconcile must reach ActiveStatus when all setup steps succeed.""" + with reconcile_guards(): state_out = ctx.run(ctx.on.config_changed(), base_state) assert isinstance(state_out.unit_status, ops.ActiveStatus) -def test_configure_blocks_when_dovecot_setup_fails(ctx, base_state): +def test_reconcile_opens_mail_ports(ctx, base_state): + """All required IMAP/POP3/Sieve/metrics ports must be opened.""" + with reconcile_guards(): + state_out = ctx.run(ctx.on.config_changed(), base_state) + + expected = {ops.testing.TCPPort(p) for p in [993, 995, 4190, 9900]} + assert state_out.opened_ports == expected + + +def test_reconcile_blocks_when_dovecot_setup_fails(ctx, base_state): + """Charm must be Blocked when _setup_dovecot raises ConfigurationError.""" with ( patch("charm.ensure_storage_ready"), patch("charm.teardown_detaching_storage"), patch("charm.shutil.which", return_value="/usr/bin/doveconf"), patch("charm.DovecotCharm._setup_tls"), + # _setup_dovecot raises — this is the condition under test patch( "charm.DovecotCharm._setup_dovecot", side_effect=ConfigurationError( @@ -62,13 +92,15 @@ def test_configure_blocks_when_dovecot_setup_fails(ctx, base_state): assert "Invalid Dovecot configuration" in state_out.unit_status.message -def test_configure_blocks_when_procmail_setup_fails(ctx, base_state): +def test_reconcile_blocks_when_procmail_setup_fails(ctx, base_state): + """Charm must be Blocked when _setup_procmail raises ConfigurationError.""" with ( patch("charm.ensure_storage_ready"), patch("charm.teardown_detaching_storage"), patch("charm.shutil.which", return_value="/usr/bin/doveconf"), patch("charm.DovecotCharm._setup_tls"), patch("charm.DovecotCharm._setup_dovecot"), + # _setup_procmail raises — this is the condition under test patch( "charm.DovecotCharm._setup_procmail", side_effect=ConfigurationError("Failed to configure postfix: error"), @@ -80,12 +112,86 @@ def test_configure_blocks_when_procmail_setup_fails(ctx, base_state): assert "postfix" in state_out.unit_status.message -# --- Clear-queue action tests --- +# --------------------------------------------------------------------------- +# HA: _is_primary +# --------------------------------------------------------------------------- + + +def test_is_primary_true_when_unit_matches_config(ctx, base_state): + """_is_primary returns True when primary-unit config matches this unit.""" + # base_state has primary-unit=dovecot-charm/0; the ctx app_name gives unit dovecot-charm/0 + with reconcile_guards(), ctx(ctx.on.config_changed(), base_state) as mgr: + assert mgr.charm._is_primary is True + + +def test_is_primary_false_when_unit_differs(ctx, base_state): + """_is_primary returns False when primary-unit config doesn't match this unit. + + We access the charm inside the context manager before the event fires, + so no _reconcile I/O is reached — no patches needed for the HA methods. + Config validation is bypassed by patching _get_dovecot_config. + """ + state_in = dataclasses.replace( + base_state, config={**base_state.config, "primary-unit": "dovecot-charm/99"} + ) + with ( + # config validation rejects unknown units — bypass it since we're only testing _is_primary + patch("charm.DovecotCharm._get_dovecot_config"), + patch("charm.ensure_storage_ready"), + patch("charm.teardown_detaching_storage"), + patch("charm.shutil.which", return_value=None), + ctx(ctx.on.config_changed(), state_in) as mgr, + ): + assert mgr.charm._is_primary is False + + +# --------------------------------------------------------------------------- +# HA: reconcile calls sync script only on primary with known secondary +# --------------------------------------------------------------------------- + + +def test_reconcile_skips_sync_script_when_not_primary(ctx, base_state): + """When this unit is NOT primary, sync script and cronjob are not installed.""" + # Use a valid config but override _is_primary to False to bypass pydantic + # validation (which requires primary-unit to match an existing unit). + with ( + patch("charm.ensure_storage_ready"), + patch("charm.teardown_detaching_storage"), + patch("charm.shutil.which", return_value="/usr/bin/doveconf"), + patch("charm.DovecotCharm._setup_tls"), + patch("charm.DovecotCharm._setup_dovecot"), + patch("charm.DovecotCharm._setup_procmail"), + # ssh keygen — real subprocess not under test + patch("charm.DovecotCharm._setup_ssh_keys"), + # authorized_keys sync — not under test + patch("charm.DovecotCharm._sync_authorized_keys"), + # Override _is_primary to simulate being a non-primary unit + patch("charm.DovecotCharm._is_primary", new_callable=PropertyMock, return_value=False), + # These should NOT be called — we verify via state not mocks + patch("charm.DovecotCharm._install_mail_sync_script") as mock_sync, + patch("charm.DovecotCharm._setup_mail_sync_cronjob") as mock_cron, + ): + state_out = ctx.run(ctx.on.config_changed(), base_state) + + # Charm still reaches Active even without sync scripts + assert isinstance(state_out.unit_status, ops.ActiveStatus) + # Secondary check: these should not have been called since unit is not primary + mock_sync.assert_not_called() + mock_cron.assert_not_called() + + +# --------------------------------------------------------------------------- +# Clear-queue action +# --------------------------------------------------------------------------- def test_clear_queue_deferred(ctx, base_state): + """clear-queue action with queue=deferred passes correct args to postsuper.""" mock_result = MagicMock(stdout="cleared") - with patch("charm.subprocess.run", return_value=mock_result) as mock_run: + with ( + # postsuper is the only subprocess call in this action path + patch("charm.subprocess.run", return_value=mock_result) as mock_run, + ): ctx.run( ctx.on.action("clear-queue", params={"queue": "deferred"}), base_state, @@ -100,8 +206,12 @@ def test_clear_queue_deferred(ctx, base_state): def test_clear_queue_all(ctx, base_state): + """clear-queue action with queue=all omits the deferred queue filter.""" mock_result = MagicMock(stdout="cleared") - with patch("charm.subprocess.run", return_value=mock_result) as mock_run: + with ( + # postsuper is the only subprocess call in this action path + patch("charm.subprocess.run", return_value=mock_result) as mock_run, + ): ctx.run( ctx.on.action("clear-queue", params={"queue": "all"}), base_state, @@ -116,7 +226,9 @@ def test_clear_queue_all(ctx, base_state): def test_clear_queue_failure(ctx, base_state): + """clear-queue action must fail when postsuper returns non-zero.""" with ( + # simulate postsuper failure patch( "charm.subprocess.run", side_effect=CalledProcessError(1, "postsuper", stderr="error msg"), @@ -130,45 +242,18 @@ def test_clear_queue_failure(ctx, base_state): assert "postsuper" in exc_info.value.message -def test_install_calls_all_setup_steps(ctx, base_state): - with ( - patch("charm.apt") as mock_apt, - patch("charm.shutil.copy") as mock_copy, - patch("charm.DovecotCharm._open_ports") as mock_open_ports, - patch("charm.DovecotCharm._setup_dovecot") as mock_dovecot, - patch("charm.DovecotCharm._setup_procmail") as mock_procmail, - patch("charm.DovecotCharm._setup_ssh_keys") as mock_setup_ssh, - patch("charm.DovecotCharm._install_mail_sync_script"), - patch("charm.DovecotCharm._setup_mail_sync_cronjob"), - ): - ctx.run(ctx.on.install(), base_state) - - mock_apt.update.assert_called_once() - mock_apt.add_package.assert_called_once() - mock_copy.assert_called_once_with("/etc/hostname", "/etc/mailname") - mock_open_ports.assert_called_once() - mock_dovecot.assert_called_once() - mock_procmail.assert_called_once() - mock_setup_ssh.assert_called_once() - - -def test_is_primary_true(ctx, base_state): - with patch("charm.DovecotCharm._install"), ctx(ctx.on.config_changed(), base_state) as mgr: - assert mgr.charm._is_primary is True - - -def test_is_primary_false(ctx, base_state): - state_in = dataclasses.replace( - base_state, config={**base_state.config, "primary-unit": "dovecot-charm/999"} - ) - with patch("charm.DovecotCharm._install"), ctx(ctx.on.config_changed(), state_in) as mgr: - assert mgr.charm._is_primary is False +# --------------------------------------------------------------------------- +# Force-sync action +# --------------------------------------------------------------------------- def test_force_sync_success(ctx, base_state): + """force-sync succeeds when this unit is primary and a secondary exists.""" mock_result = MagicMock(stdout="ok", stderr="") with ( + # sync script subprocess call — the action delegates to the shell script patch("charm.subprocess.run", return_value=mock_result), + # provide a secondary hostname so the action doesn't bail out patch.object( DovecotCharm, "_secondary_hostname", @@ -181,26 +266,32 @@ def test_force_sync_success(ctx, base_state): def test_force_sync_not_primary(ctx, base_state): - state_in = dataclasses.replace( - base_state, config={**base_state.config, "primary-unit": "dovecot-charm/999"} - ) - with pytest.raises(ops.testing.ActionFailed) as exc_info: - ctx.run(ctx.on.action("force-sync"), state_in) + """force-sync must fail when executed on a non-primary unit.""" + # Override _is_primary since pydantic rejects unknown unit names + with ( + patch("charm.DovecotCharm._is_primary", new_callable=PropertyMock, return_value=False), + pytest.raises(ops.testing.ActionFailed) as exc_info, + ): + ctx.run(ctx.on.action("force-sync"), base_state) assert "primary unit" in exc_info.value.message def test_force_sync_no_secondary(ctx, base_state): + """force-sync must fail when no secondary unit hostname is available.""" with pytest.raises(ops.testing.ActionFailed) as exc_info: ctx.run(ctx.on.action("force-sync"), base_state) assert "secondary" in exc_info.value.message def test_force_sync_subprocess_failure(ctx, base_state): + """force-sync must fail when the sync script exits non-zero.""" with ( + # sync script fails patch( "charm.subprocess.run", side_effect=CalledProcessError(1, "sync", stderr="fail"), ), + # provide secondary so the action reaches subprocess.run patch.object( DovecotCharm, "_secondary_hostname", diff --git a/dovecot-charm/tests/unit/test_storage.py b/dovecot-charm/tests/unit/test_storage.py index c204655..4306c52 100644 --- a/dovecot-charm/tests/unit/test_storage.py +++ b/dovecot-charm/tests/unit/test_storage.py @@ -29,6 +29,11 @@ def test_start_uses_saved_dev_path_when_model_error(ctx, base_state): patch("charm.DovecotCharm._setup_tls"), patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), + # HA methods do filesystem I/O (ssh-keygen, authorized_keys, sync scripts) + patch("charm.DovecotCharm._setup_ssh_keys"), + patch("charm.DovecotCharm._sync_authorized_keys"), + patch("charm.DovecotCharm._install_mail_sync_script"), + patch("charm.DovecotCharm._setup_mail_sync_cronjob"), patch("ops._main._Dispatcher.run_any_legacy_hook"), ): state_out = ctx.run(ctx.on.start(), state_in) @@ -72,6 +77,11 @@ def test_storage_attached_luks_auto_provisioning_disabled_mounted_is_active(ctx, patch("charm.DovecotCharm._setup_tls"), patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), + # HA methods do filesystem I/O — not under test + patch("charm.DovecotCharm._setup_ssh_keys"), + patch("charm.DovecotCharm._sync_authorized_keys"), + patch("charm.DovecotCharm._install_mail_sync_script"), + patch("charm.DovecotCharm._setup_mail_sync_cronjob"), ): state_out = ctx.run(ctx.on.storage_attached(storage), state_in) assert isinstance(state_out.unit_status, ops.ActiveStatus) @@ -105,6 +115,11 @@ def test_storage_attached_calls_setup_luks_with_key(ctx, base_state): patch("charm.DovecotCharm._setup_tls"), patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), + # HA methods do filesystem I/O — not under test + patch("charm.DovecotCharm._setup_ssh_keys"), + patch("charm.DovecotCharm._sync_authorized_keys"), + patch("charm.DovecotCharm._install_mail_sync_script"), + patch("charm.DovecotCharm._setup_mail_sync_cronjob"), ): state_out = ctx.run(ctx.on.storage_attached(storage), state_in) assert isinstance(state_out.unit_status, ops.ActiveStatus) @@ -125,6 +140,11 @@ def test_storage_attached_saves_dev_path(ctx, base_state): patch("charm.DovecotCharm._setup_tls"), patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), + # HA methods do filesystem I/O — not under test + patch("charm.DovecotCharm._setup_ssh_keys"), + patch("charm.DovecotCharm._sync_authorized_keys"), + patch("charm.DovecotCharm._install_mail_sync_script"), + patch("charm.DovecotCharm._setup_mail_sync_cronjob"), ): state_out = ctx.run(ctx.on.storage_attached(storage), state_in) assert isinstance(state_out.unit_status, ops.ActiveStatus) @@ -184,6 +204,11 @@ def test_storage_detaching_unmount_and_close(ctx, base_state): patch("charm.DovecotCharm._setup_tls"), patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), + # HA methods do filesystem I/O — not under test + patch("charm.DovecotCharm._setup_ssh_keys"), + patch("charm.DovecotCharm._sync_authorized_keys"), + patch("charm.DovecotCharm._install_mail_sync_script"), + patch("charm.DovecotCharm._setup_mail_sync_cronjob"), ): state_out = ctx.run(ctx.on.storage_detaching(storage), state_in) assert isinstance(state_out.unit_status, ops.ActiveStatus) @@ -219,6 +244,11 @@ def test_storage_detaching_luks_disabled_skips_close(ctx, base_state): patch("charm.DovecotCharm._setup_tls"), patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), + # HA methods do filesystem I/O — not under test + patch("charm.DovecotCharm._setup_ssh_keys"), + patch("charm.DovecotCharm._sync_authorized_keys"), + patch("charm.DovecotCharm._install_mail_sync_script"), + patch("charm.DovecotCharm._setup_mail_sync_cronjob"), ): state_out = ctx.run(ctx.on.storage_detaching(storage), state_in) assert isinstance(state_out.unit_status, ops.ActiveStatus) diff --git a/dovecot-charm/tests/unit/test_tls.py b/dovecot-charm/tests/unit/test_tls.py index b31756d..4a28491 100644 --- a/dovecot-charm/tests/unit/test_tls.py +++ b/dovecot-charm/tests/unit/test_tls.py @@ -49,6 +49,11 @@ def test_setup_tls_writes_cert_key_and_chain(ctx, base_state, tmp_path): # Isolate from dovecot/procmail filesystem writes patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), + # HA methods do filesystem I/O (ssh-keygen, authorized_keys, sync scripts) + patch("charm.DovecotCharm._setup_ssh_keys"), + patch("charm.DovecotCharm._sync_authorized_keys"), + patch("charm.DovecotCharm._install_mail_sync_script"), + patch("charm.DovecotCharm._setup_mail_sync_cronjob"), ctx(ctx.on.config_changed(), base_state) as mgr, ): # Override the TLS library instance so get_assigned_certificate @@ -87,6 +92,11 @@ def test_setup_tls_no_ca_omits_chain(ctx, base_state, tmp_path): patch("charm.shutil.which", return_value="/usr/bin/doveconf"), patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), + # HA methods do filesystem I/O — not under test + patch("charm.DovecotCharm._setup_ssh_keys"), + patch("charm.DovecotCharm._sync_authorized_keys"), + patch("charm.DovecotCharm._install_mail_sync_script"), + patch("charm.DovecotCharm._setup_mail_sync_cronjob"), ctx(ctx.on.config_changed(), base_state) as mgr, ): mgr.charm._tls = MagicMock() @@ -143,6 +153,11 @@ def test_certificate_available_event_triggers_reconcile(ctx, base_state, tmp_pat ), patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), + # HA methods do filesystem I/O — not under test + patch("charm.DovecotCharm._setup_ssh_keys"), + patch("charm.DovecotCharm._sync_authorized_keys"), + patch("charm.DovecotCharm._install_mail_sync_script"), + patch("charm.DovecotCharm._setup_mail_sync_cronjob"), ): # Fire certificate_available via config_changed (same handler) state_out = ctx.run(ctx.on.config_changed(), base_state) From 0b2026418d8f150ca485f0e55c90ce44eb20c012 Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Mon, 20 Apr 2026 13:47:29 +0300 Subject: [PATCH 25/39] feat(tests): add integration tests for high availability support --- .github/workflows/integration_test.yaml | 1 + docs/release-notes/artifacts/pr-4-ha.yaml | 22 ++++++++++++++++++---- dovecot-charm/src/charm.py | 6 +++--- dovecot-charm/tests/unit/test_charm.py | 2 +- 4 files changed, 23 insertions(+), 8 deletions(-) diff --git a/.github/workflows/integration_test.yaml b/.github/workflows/integration_test.yaml index 92261d6..3d11bfa 100644 --- a/.github/workflows/integration_test.yaml +++ b/.github/workflows/integration_test.yaml @@ -26,6 +26,7 @@ jobs: "test_mail.py", "test_storage.py", "test_tls.py", + "test_ha.py", ] allure-report: if: ${{ !cancelled() && github.event_name == 'schedule' }} diff --git a/docs/release-notes/artifacts/pr-4-ha.yaml b/docs/release-notes/artifacts/pr-4-ha.yaml index 300a046..0b7b284 100644 --- a/docs/release-notes/artifacts/pr-4-ha.yaml +++ b/docs/release-notes/artifacts/pr-4-ha.yaml @@ -1,4 +1,18 @@ -name: pr-4-ha -type: major -summary: HA support with SSH key exchange and force-sync action -url: https://github.com/canonical/mailserver-operators/pull/4 +# Copyright 2026 Canonical Ltd. +# See LICENSE file for licensing details. + +# Version of the artifact schema +version_schema: 2 + +changes: +- title: Added HA support with SSH key exchange and force-sync action + author: alithethird + type: major + description: Added high availability support for Dovecot with automatic SSH key exchange between primary and secondary units via the replicas peer relation, rsync-based mail synchronization via cron, and a force-sync Juju action for on-demand replication. + urls: + pr: + - "https://github.com/canonical/mailserver-operators/pull/15" + related_doc: + related_issue: + visibility: public + highlight: true diff --git a/dovecot-charm/src/charm.py b/dovecot-charm/src/charm.py index bc719fe..6e46089 100644 --- a/dovecot-charm/src/charm.py +++ b/dovecot-charm/src/charm.py @@ -345,8 +345,8 @@ def _setup_ssh_keys(self): key_file = SSH_DIR / "id_ed25519" if not key_file.exists(): - subprocess.run( # noqa: S603 - ["ssh-keygen", "-t", "ed25519", "-N", "", "-f", str(key_file)], + subprocess.run( + ["/usr/bin/ssh-keygen", "-t", "ed25519", "-N", "", "-f", str(key_file)], check=True, capture_output=True, ) @@ -483,7 +483,7 @@ def _on_force_sync(self, event): try: cmd = [SYNC_TO_SECONDARY_TARGET] logger.info(f"Running manual sync: {' '.join(cmd)}") - subprocess.run(cmd, check=True, capture_output=True, text=True) # noqa: S603 + subprocess.run(cmd, check=True, capture_output=True, text=True) event.set_results({"result": "Sync completed successfully"}) except subprocess.CalledProcessError as e: msg = f"Sync failed: {e.stderr}" diff --git a/dovecot-charm/tests/unit/test_charm.py b/dovecot-charm/tests/unit/test_charm.py index 4942018..1d357cd 100644 --- a/dovecot-charm/tests/unit/test_charm.py +++ b/dovecot-charm/tests/unit/test_charm.py @@ -12,11 +12,11 @@ from charm import DovecotCharm from exceptions import ConfigurationError - # --------------------------------------------------------------------------- # Helpers — patches shared across many tests # --------------------------------------------------------------------------- + @contextlib.contextmanager def reconcile_guards(): """Guard all I/O in _reconcile so tests only exercise event wiring / status. From 4854b3dd55570da6c5ee343c95195705303fa269 Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Mon, 20 Apr 2026 15:24:13 +0300 Subject: [PATCH 26/39] feat(ha): add known_hosts synchronization for SSH key exchange --- dovecot-charm/src/charm.py | 45 ++++++++++++++++++++---- dovecot-charm/tests/unit/test_charm.py | 4 +++ dovecot-charm/tests/unit/test_storage.py | 6 ++++ dovecot-charm/tests/unit/test_tls.py | 3 ++ 4 files changed, 52 insertions(+), 6 deletions(-) diff --git a/dovecot-charm/src/charm.py b/dovecot-charm/src/charm.py index 6e46089..ed826bf 100644 --- a/dovecot-charm/src/charm.py +++ b/dovecot-charm/src/charm.py @@ -51,6 +51,7 @@ SSHD_CONFIG = Path("/etc/ssh/sshd_config") SSH_DIR = Path("/root/.ssh") +SSH_HOST_KEY_FILE = Path("/etc/ssh/ssh_host_ed25519_key.pub") class DovecotCharm(CharmBase): @@ -194,9 +195,10 @@ def _reconcile(self, event): except ConfigurationError as e: self.unit.status = BlockedStatus(str(e)) return - # HA: SSH keys, authorized_keys, sync script + cronjob + # HA: SSH keys, authorized_keys, known_hosts, sync script + cronjob self._setup_ssh_keys() self._sync_authorized_keys() + self._sync_known_hosts() if self._is_primary: self._install_mail_sync_script() self._setup_mail_sync_cronjob() @@ -340,7 +342,12 @@ def _setup_tls(self, dovecot_config: DovecotConfig) -> None: # -- HA / SSH key exchange ------------------------------------------------ def _setup_ssh_keys(self): - """Generate an SSH key pair if absent and publish the public key via the peer relation.""" + """Generate an SSH key pair if absent and publish keys via the peer relation. + + Publishes both the user public key (for authorized_keys) and the host + public key (for known_hosts) so peers can verify each other's identity + without disabling StrictHostKeyChecking. + """ SSH_DIR.mkdir(mode=0o700, exist_ok=True) key_file = SSH_DIR / "id_ed25519" @@ -362,10 +369,10 @@ def _setup_ssh_keys(self): relation.data[self.unit]["public_key"] = pub_key relation.data[self.unit]["hostname"] = socket.gethostname() - config_file = SSH_DIR / "config" - if not config_file.exists(): - config_file.write_text("Host *\n StrictHostKeyChecking no\n") - config_file.chmod(0o600) + # Publish the host public key so peers can populate known_hosts + if SSH_HOST_KEY_FILE.exists(): + host_key = SSH_HOST_KEY_FILE.read_text().strip() + relation.data[self.unit]["ssh_host_key"] = host_key def _sync_authorized_keys(self): """Collect public keys from all peer units and write authorized_keys.""" @@ -392,6 +399,32 @@ def _sync_authorized_keys(self): self._ensure_root_ssh_login() + def _sync_known_hosts(self): + """Populate known_hosts with peer SSH host keys from the peer relation. + + Each peer publishes its host public key and hostname on the relation. + This method writes those into known_hosts so SSH connections between + units use StrictHostKeyChecking (the default) instead of disabling it. + """ + relation = self.model.get_relation(PEER_RELATION_NAME) + if not relation: + return + + entries = [] + for unit in relation.units: + host_key = relation.data[unit].get("ssh_host_key") + hostname = relation.data[unit].get("hostname") + if host_key and hostname: + # known_hosts format: + entries.append(f"{hostname} {host_key}") + + if not entries: + return + + known_hosts_file = SSH_DIR / "known_hosts" + known_hosts_file.write_text("\n".join(entries) + "\n") + known_hosts_file.chmod(0o600) + def _ensure_root_ssh_login(self): """Set PermitRootLogin to prohibit-password in sshd_config and reload sshd.""" if SSHD_CONFIG.exists(): diff --git a/dovecot-charm/tests/unit/test_charm.py b/dovecot-charm/tests/unit/test_charm.py index 1d357cd..3d17885 100644 --- a/dovecot-charm/tests/unit/test_charm.py +++ b/dovecot-charm/tests/unit/test_charm.py @@ -40,6 +40,8 @@ def reconcile_guards(): patch("charm.DovecotCharm._setup_ssh_keys"), # authorized_keys sync — not under test patch("charm.DovecotCharm._sync_authorized_keys"), + # known_hosts sync — not under test + patch("charm.DovecotCharm._sync_known_hosts"), # sync script rendering — not under test patch("charm.DovecotCharm._install_mail_sync_script"), # cronjob rendering + cron restart — not under test @@ -165,6 +167,8 @@ def test_reconcile_skips_sync_script_when_not_primary(ctx, base_state): patch("charm.DovecotCharm._setup_ssh_keys"), # authorized_keys sync — not under test patch("charm.DovecotCharm._sync_authorized_keys"), + # known_hosts sync — not under test + patch("charm.DovecotCharm._sync_known_hosts"), # Override _is_primary to simulate being a non-primary unit patch("charm.DovecotCharm._is_primary", new_callable=PropertyMock, return_value=False), # These should NOT be called — we verify via state not mocks diff --git a/dovecot-charm/tests/unit/test_storage.py b/dovecot-charm/tests/unit/test_storage.py index 4306c52..106460f 100644 --- a/dovecot-charm/tests/unit/test_storage.py +++ b/dovecot-charm/tests/unit/test_storage.py @@ -32,6 +32,7 @@ def test_start_uses_saved_dev_path_when_model_error(ctx, base_state): # HA methods do filesystem I/O (ssh-keygen, authorized_keys, sync scripts) patch("charm.DovecotCharm._setup_ssh_keys"), patch("charm.DovecotCharm._sync_authorized_keys"), + patch("charm.DovecotCharm._sync_known_hosts"), patch("charm.DovecotCharm._install_mail_sync_script"), patch("charm.DovecotCharm._setup_mail_sync_cronjob"), patch("ops._main._Dispatcher.run_any_legacy_hook"), @@ -80,6 +81,7 @@ def test_storage_attached_luks_auto_provisioning_disabled_mounted_is_active(ctx, # HA methods do filesystem I/O — not under test patch("charm.DovecotCharm._setup_ssh_keys"), patch("charm.DovecotCharm._sync_authorized_keys"), + patch("charm.DovecotCharm._sync_known_hosts"), patch("charm.DovecotCharm._install_mail_sync_script"), patch("charm.DovecotCharm._setup_mail_sync_cronjob"), ): @@ -118,6 +120,7 @@ def test_storage_attached_calls_setup_luks_with_key(ctx, base_state): # HA methods do filesystem I/O — not under test patch("charm.DovecotCharm._setup_ssh_keys"), patch("charm.DovecotCharm._sync_authorized_keys"), + patch("charm.DovecotCharm._sync_known_hosts"), patch("charm.DovecotCharm._install_mail_sync_script"), patch("charm.DovecotCharm._setup_mail_sync_cronjob"), ): @@ -143,6 +146,7 @@ def test_storage_attached_saves_dev_path(ctx, base_state): # HA methods do filesystem I/O — not under test patch("charm.DovecotCharm._setup_ssh_keys"), patch("charm.DovecotCharm._sync_authorized_keys"), + patch("charm.DovecotCharm._sync_known_hosts"), patch("charm.DovecotCharm._install_mail_sync_script"), patch("charm.DovecotCharm._setup_mail_sync_cronjob"), ): @@ -207,6 +211,7 @@ def test_storage_detaching_unmount_and_close(ctx, base_state): # HA methods do filesystem I/O — not under test patch("charm.DovecotCharm._setup_ssh_keys"), patch("charm.DovecotCharm._sync_authorized_keys"), + patch("charm.DovecotCharm._sync_known_hosts"), patch("charm.DovecotCharm._install_mail_sync_script"), patch("charm.DovecotCharm._setup_mail_sync_cronjob"), ): @@ -247,6 +252,7 @@ def test_storage_detaching_luks_disabled_skips_close(ctx, base_state): # HA methods do filesystem I/O — not under test patch("charm.DovecotCharm._setup_ssh_keys"), patch("charm.DovecotCharm._sync_authorized_keys"), + patch("charm.DovecotCharm._sync_known_hosts"), patch("charm.DovecotCharm._install_mail_sync_script"), patch("charm.DovecotCharm._setup_mail_sync_cronjob"), ): diff --git a/dovecot-charm/tests/unit/test_tls.py b/dovecot-charm/tests/unit/test_tls.py index 4a28491..3a20c00 100644 --- a/dovecot-charm/tests/unit/test_tls.py +++ b/dovecot-charm/tests/unit/test_tls.py @@ -52,6 +52,7 @@ def test_setup_tls_writes_cert_key_and_chain(ctx, base_state, tmp_path): # HA methods do filesystem I/O (ssh-keygen, authorized_keys, sync scripts) patch("charm.DovecotCharm._setup_ssh_keys"), patch("charm.DovecotCharm._sync_authorized_keys"), + patch("charm.DovecotCharm._sync_known_hosts"), patch("charm.DovecotCharm._install_mail_sync_script"), patch("charm.DovecotCharm._setup_mail_sync_cronjob"), ctx(ctx.on.config_changed(), base_state) as mgr, @@ -95,6 +96,7 @@ def test_setup_tls_no_ca_omits_chain(ctx, base_state, tmp_path): # HA methods do filesystem I/O — not under test patch("charm.DovecotCharm._setup_ssh_keys"), patch("charm.DovecotCharm._sync_authorized_keys"), + patch("charm.DovecotCharm._sync_known_hosts"), patch("charm.DovecotCharm._install_mail_sync_script"), patch("charm.DovecotCharm._setup_mail_sync_cronjob"), ctx(ctx.on.config_changed(), base_state) as mgr, @@ -156,6 +158,7 @@ def test_certificate_available_event_triggers_reconcile(ctx, base_state, tmp_pat # HA methods do filesystem I/O — not under test patch("charm.DovecotCharm._setup_ssh_keys"), patch("charm.DovecotCharm._sync_authorized_keys"), + patch("charm.DovecotCharm._sync_known_hosts"), patch("charm.DovecotCharm._install_mail_sync_script"), patch("charm.DovecotCharm._setup_mail_sync_cronjob"), ): From dc37e92eadeabed880b42871ec7dfd0318587506 Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Tue, 21 Apr 2026 10:54:55 +0300 Subject: [PATCH 27/39] feat(ha): ensure system user exists for doveadm user lookup in sync script --- dovecot-charm/tests/integration/test_ha.py | 31 +++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/dovecot-charm/tests/integration/test_ha.py b/dovecot-charm/tests/integration/test_ha.py index 4fccd38..5211df9 100644 --- a/dovecot-charm/tests/integration/test_ha.py +++ b/dovecot-charm/tests/integration/test_ha.py @@ -77,9 +77,34 @@ def two_units_active(status): logging.info("Running force-sync on Primary...") - # Create a test Maildir so the sync script has something to sync. - # Without this, the script exits 1 because no Maildir directories exist. - juju.exec("mkdir -p /srv/mail/testuser/Maildir/{new,cur,tmp}", unit=primary) + # Ensure a real system user exists for doveadm user lookup. + # A bare /srv/mail/ directory is not enough for dsync. + sync_user = "syncuser" + for unit in (primary, secondary): + juju.exec("rm -rf /srv/mail/syncuser /srv/mail/sync-* /srv/mail/testuser", unit=unit) + + juju.exec( + ( + f"id -u {sync_user} >/dev/null 2>&1 || " + f"useradd -M -d /srv/mail/{sync_user} -s /usr/sbin/nologin {sync_user}" + ), + unit=primary, + ) + juju.exec( + ( + f"mkdir -p /srv/mail/{sync_user}/Maildir/{{new,cur,tmp}} && " + f"chown -R {sync_user}:{sync_user} /srv/mail/{sync_user} && " + f"chmod 700 /srv/mail/{sync_user} /srv/mail/{sync_user}/Maildir" + ), + unit=primary, + ) + juju.exec( + ( + f"id -u {sync_user} >/dev/null 2>&1 || " + f"useradd -M -d /srv/mail/{sync_user} -s /usr/sbin/nologin {sync_user}" + ), + unit=secondary, + ) task = juju.run(unit=primary, action="force-sync", wait=100) assert task.status == "completed" From 77719960abf844d641e4b7fd8abc62d70c18abf1 Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Tue, 21 Apr 2026 11:19:48 +0300 Subject: [PATCH 28/39] refactor(tests): streamline TLS setup tests by removing unnecessary patches --- dovecot-charm/tests/unit/test_charm.py | 8 -------- dovecot-charm/tests/unit/test_tls.py | 13 +------------ 2 files changed, 1 insertion(+), 20 deletions(-) diff --git a/dovecot-charm/tests/unit/test_charm.py b/dovecot-charm/tests/unit/test_charm.py index 6e68e7a..2170f3a 100644 --- a/dovecot-charm/tests/unit/test_charm.py +++ b/dovecot-charm/tests/unit/test_charm.py @@ -39,19 +39,11 @@ def reconcile_guards(): ): yield - expected = {ops.testing.TCPPort(p) for p in [993, 995, 4190, 9900]} - assert state_out.opened_ports == expected - -# --------------------------------------------------------------------------- -# Reconcile: status + ports -# --------------------------------------------------------------------------- - def test_reconcile_sets_active_on_success(ctx, base_state): """Reconcile must reach ActiveStatus when all setup steps succeed.""" with reconcile_guards(): state_out = ctx.run(ctx.on.config_changed(), base_state) - assert isinstance(state_out.unit_status, ops.ActiveStatus) diff --git a/dovecot-charm/tests/unit/test_tls.py b/dovecot-charm/tests/unit/test_tls.py index e7f7093..86d95aa 100644 --- a/dovecot-charm/tests/unit/test_tls.py +++ b/dovecot-charm/tests/unit/test_tls.py @@ -42,13 +42,12 @@ def test_setup_tls_writes_cert_key_and_chain(ctx, base_state, tmp_path): mock_key.__str__ = MagicMock(return_value="KEY_DATA") with ( - # Redirect TLS_CERT_DIR so _setup_tls writes into tmp_path patch("charm.TLS_CERT_DIR", tmp_path), patch("charm.ensure_storage_ready"), patch("charm.shutil.which", return_value="/usr/bin/doveconf"), - # Isolate from dovecot/procmail filesystem writes patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), + patch("charm.DovecotCharm._setup_ssh_keys"), ctx(ctx.on.config_changed(), base_state) as mgr, ): # Override the TLS library instance so get_assigned_certificate @@ -87,12 +86,7 @@ def test_setup_tls_no_ca_omits_chain(ctx, base_state, tmp_path): patch("charm.shutil.which", return_value="/usr/bin/doveconf"), patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), - # HA methods do filesystem I/O — not under test patch("charm.DovecotCharm._setup_ssh_keys"), - patch("charm.DovecotCharm._sync_authorized_keys"), - patch("charm.DovecotCharm._sync_known_hosts"), - patch("charm.DovecotCharm._install_mail_sync_script"), - patch("charm.DovecotCharm._setup_mail_sync_cronjob"), ctx(ctx.on.config_changed(), base_state) as mgr, ): mgr.charm._tls = MagicMock() @@ -149,12 +143,7 @@ def test_certificate_available_event_triggers_reconcile(ctx, base_state, tmp_pat ), patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), - # HA methods do filesystem I/O — not under test patch("charm.DovecotCharm._setup_ssh_keys"), - patch("charm.DovecotCharm._sync_authorized_keys"), - patch("charm.DovecotCharm._sync_known_hosts"), - patch("charm.DovecotCharm._install_mail_sync_script"), - patch("charm.DovecotCharm._setup_mail_sync_cronjob"), ): # Fire certificate_available via config_changed (same handler) state_out = ctx.run(ctx.on.config_changed(), base_state) From c0e37e668758d43114b6e6bf2f4888e618db92e8 Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Tue, 21 Apr 2026 11:47:51 +0300 Subject: [PATCH 29/39] refactor: remove redundant TLS setup method from DovecotCharm --- dovecot-charm/src/charm.py | 43 -------------------------------------- 1 file changed, 43 deletions(-) diff --git a/dovecot-charm/src/charm.py b/dovecot-charm/src/charm.py index a1ef709..994b500 100644 --- a/dovecot-charm/src/charm.py +++ b/dovecot-charm/src/charm.py @@ -298,49 +298,6 @@ def _setup_procmail(self) -> None: logger.exception(f"Failed to configure postfix: {e}") raise ConfigurationError(f"Failed to configure postfix: {e.stderr}") from e - def _setup_tls(self, dovecot_config: DovecotConfig) -> None: - """Write TLS cert+key to disk from the certificates relation. - - Called from _reconcile before _setup_dovecot so the cert files are - present when dovecot.conf is rendered and validated. - - Raises: - ConfigurationError: If no TLS relation exists or the certificate - has not been issued yet. - """ - if not self._tls: - raise ConfigurationError( - "TLS certificates relation not available. " - "Integrate with a TLS provider using the 'certificates' relation." - ) - - cert_request = CertificateRequestAttributes( - common_name=dovecot_config.mailname, - sans_dns=frozenset([dovecot_config.mailname]), - ) - provider_cert, private_key = self._tls.get_assigned_certificate(cert_request) - if not provider_cert or not private_key: - raise ConfigurationError( - "TLS certificate not yet available from the certificates relation." - ) - - TLS_CERT_DIR.mkdir(parents=True, exist_ok=True) - cert_path = TLS_CERT_DIR / f"{dovecot_config.mailname}.pem" - key_path = TLS_CERT_DIR / f"{dovecot_config.mailname}.key" - - cert_content = str(provider_cert.certificate) - if provider_cert.ca: - cert_content += "\n" + str(provider_cert.ca) - cert_path.write_text(cert_content) - cert_path.chmod(0o644) - logger.info(f"TLS certificate written to {cert_path}") - - key_path.write_text(str(private_key)) - key_path.chmod(0o600) - logger.info(f"TLS private key written to {key_path}") - - # -- HA / SSH key exchange ------------------------------------------------ - def _setup_ssh_keys(self): """Generate an SSH key pair if absent and publish keys via the peer relation. From ec1bb23c4da1ca3f97b8c6998ac2a2719aafd4a0 Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Tue, 21 Apr 2026 12:58:30 +0300 Subject: [PATCH 30/39] feat(ha): implement high availability functions and refactor SSH key management --- dovecot-charm/src/charm.py | 192 ++--------------------- dovecot-charm/src/constants.py | 10 ++ dovecot-charm/src/ha.py | 173 ++++++++++++++++++++ dovecot-charm/tests/unit/test_charm.py | 20 +-- dovecot-charm/tests/unit/test_storage.py | 60 +++---- dovecot-charm/tests/unit/test_tls.py | 6 +- 6 files changed, 240 insertions(+), 221 deletions(-) create mode 100644 dovecot-charm/src/ha.py diff --git a/dovecot-charm/src/charm.py b/dovecot-charm/src/charm.py index 994b500..d67eda8 100644 --- a/dovecot-charm/src/charm.py +++ b/dovecot-charm/src/charm.py @@ -6,7 +6,6 @@ import logging import shutil -import socket import subprocess # nosec import typing from pathlib import Path @@ -23,6 +22,7 @@ from ops.main import main from ops.model import BlockedStatus, MaintenanceStatus +import ha from constants import ( DOVECOT_CONF_TARGET, DOVECOT_CONF_TEMPLATE, @@ -34,6 +34,7 @@ PROCMAILRC_TARGET, PROCMAILRC_TEMPLATE, REQUIRED_PACKAGES, + SYNC_TO_SECONDARY_TARGET, TEMPLATES_DIR, TLS_CERT_DIR, ) @@ -43,16 +44,6 @@ logger = logging.getLogger(__name__) -# HA sync paths -SYNC_TO_SECONDARY_TARGET = "/usr/local/bin/sync-to-secondary.sh" -SYNC_TO_SECONDARY_CRONJOB_TARGET = "/etc/cron.d/sync-to-secondary" -SYNC_TO_SECONDARY_TEMPLATE = "sync-to-secondary.sh.tmpl" -SYNC_TO_SECONDARY_CRONJOB_TEMPLATE = "sync-to-secondary_cron.tmpl" - -SSHD_CONFIG = Path("/etc/ssh/sshd_config") -SSH_DIR = Path("/root/.ssh") -SSH_HOST_KEY_FILE = Path("/etc/ssh/ssh_host_ed25519_key.pub") - class DovecotCharm(CharmBase): """Dovecot IMAP/POP3 mail server charm.""" @@ -60,7 +51,6 @@ class DovecotCharm(CharmBase): def __init__(self, *args): super().__init__(*args) - # Events — every event except install goes through _reconcile. self.framework.observe(self.on.install, self._on_install) self.framework.observe(self.on.start, self._reconcile) self.framework.observe(self.on.config_changed, self._reconcile) @@ -75,12 +65,11 @@ def __init__(self, *args): self.on[PEER_RELATION_NAME].relation_created, self._on_peer_relation_created, ) - # Template system + self.jinja = jinja2.Environment( loader=jinja2.FileSystemLoader(TEMPLATES_DIR), autoescape=True ) - # TLS certificates integration self._tls = None mailname = self.config.get("mailname", "") if mailname: @@ -121,26 +110,20 @@ def _on_peer_relation_created(self, event): relation_data["unit-name"] = self.unit.name @property - def _is_primary(self): + def _is_primary(self) -> bool: """Return True if this unit is the configured primary unit.""" return self.unit.name == self.config.get("primary-unit", "") @property def _secondary_hostname(self) -> typing.Optional[str]: - """Return the hostname/IP of the first remote peer unit, or None.""" + """Return the hostname of the first remote peer unit, or None.""" relation = self.model.get_relation(PEER_RELATION_NAME) if not relation: return None - for unit in relation.units: - hostname = ( - relation.data[unit].get("hostname") - or relation.data[unit].get("private-address") - or relation.data[unit].get("ingress-address") - ) + hostname = relation.data[unit].get("hostname") if hostname: return hostname - return None def _get_dovecot_config(self) -> DovecotConfig: @@ -164,19 +147,14 @@ def _get_dovecot_config(self) -> DovecotConfig: logger.exception(f"Secret retrieval error: {exc}") raise ConfigurationError(str(exc)) from exc - # -- Event handlers ------------------------------------------------------- - def _on_install(self, event): - """Handle install event — install packages only, then reconcile.""" + """Handle install event.""" self.unit.status = MaintenanceStatus("Installing packages") self._install() self._reconcile(event) def _reconcile(self, event): - """Reconcile charm state for every event except install. - - Holistic handler: storage → TLS → dovecot → procmail → HA → ports. - """ + """Reconcile charm state.""" self.unit.status = MaintenanceStatus("Configuring charm") try: dovecot_config = self._get_dovecot_config() @@ -195,28 +173,23 @@ def _reconcile(self, event): except ConfigurationError as e: self.unit.status = BlockedStatus(str(e)) return - # HA: SSH keys, authorized_keys, known_hosts, sync script + cronjob - self._setup_ssh_keys() - self._sync_authorized_keys() - self._sync_known_hosts() + ha.setup_ssh_keys(self) + ha.sync_authorized_keys(self) + ha.sync_known_hosts(self) if self._is_primary: - self._install_mail_sync_script() - self._setup_mail_sync_cronjob() + ha.install_mail_sync_script(self) + ha.setup_mail_sync_cronjob(self) self._open_ports() self.unit.status = ops.ActiveStatus() - # -- Installation --------------------------------------------------------- - def _install(self): - """Perform basic installation — packages and hostname only.""" + """Install required packages and set up mailname.""" self.unit.status = MaintenanceStatus("Installing required dependencies") apt.update() apt.add_package(REQUIRED_PACKAGES) shutil.copy(HOSTNAME_FILE, MAILNAME_FILE) self.unit.status = MaintenanceStatus("Charm installation done") - # -- Service configuration ------------------------------------------------ - def _open_ports(self): """Open mail ports (TLS-only: plaintext 143/110 are not exposed).""" self.unit.open_port("tcp", 993) @@ -298,143 +271,6 @@ def _setup_procmail(self) -> None: logger.exception(f"Failed to configure postfix: {e}") raise ConfigurationError(f"Failed to configure postfix: {e.stderr}") from e - def _setup_ssh_keys(self): - """Generate an SSH key pair if absent and publish keys via the peer relation. - - Publishes both the user public key (for authorized_keys) and the host - public key (for known_hosts) so peers can verify each other's identity - without disabling StrictHostKeyChecking. - """ - SSH_DIR.mkdir(mode=0o700, exist_ok=True) - key_file = SSH_DIR / "id_ed25519" - - if not key_file.exists(): - subprocess.run( - ["/usr/bin/ssh-keygen", "-t", "ed25519", "-N", "", "-f", str(key_file)], - check=True, - capture_output=True, - ) - - pub_key_file = SSH_DIR / "id_ed25519.pub" - if not pub_key_file.exists(): - logger.error("SSH public key file not found after key generation") - return - - pub_key = pub_key_file.read_text().strip() - relation = self.model.get_relation(PEER_RELATION_NAME) - if relation: - relation.data[self.unit]["public_key"] = pub_key - relation.data[self.unit]["hostname"] = socket.gethostname() - - # Publish the host public key so peers can populate known_hosts - if SSH_HOST_KEY_FILE.exists(): - host_key = SSH_HOST_KEY_FILE.read_text().strip() - relation.data[self.unit]["ssh_host_key"] = host_key - - def _sync_authorized_keys(self): - """Collect public keys from all peer units and write authorized_keys.""" - relation = self.model.get_relation(PEER_RELATION_NAME) - if not relation: - return - - authorized_keys = [] - for unit in relation.units: - pk = relation.data[unit].get("public_key") - if pk: - authorized_keys.append(pk) - - our_pk = relation.data[self.unit].get("public_key") - if our_pk: - authorized_keys.append(our_pk) - - if not authorized_keys: - return - - auth_file = SSH_DIR / "authorized_keys" - auth_file.write_text("\n".join(authorized_keys) + "\n") - auth_file.chmod(0o600) - - self._ensure_root_ssh_login() - - def _sync_known_hosts(self): - """Populate known_hosts with peer SSH host keys from the peer relation. - - Each peer publishes its host public key and hostname on the relation. - This method writes those into known_hosts so SSH connections between - units use StrictHostKeyChecking (the default) instead of disabling it. - """ - relation = self.model.get_relation(PEER_RELATION_NAME) - if not relation: - return - - entries = [] - for unit in relation.units: - host_key = relation.data[unit].get("ssh_host_key") - hostname = relation.data[unit].get("hostname") - if host_key and hostname: - # known_hosts format: - entries.append(f"{hostname} {host_key}") - - if not entries: - return - - known_hosts_file = SSH_DIR / "known_hosts" - known_hosts_file.write_text("\n".join(entries) + "\n") - known_hosts_file.chmod(0o600) - - def _ensure_root_ssh_login(self): - """Set PermitRootLogin to prohibit-password in sshd_config and reload sshd.""" - if SSHD_CONFIG.exists(): - content = SSHD_CONFIG.read_text() - new_content = "" - found = False - for line in content.splitlines(keepends=True): - stripped = line.lstrip("#").strip() - if stripped.startswith("PermitRootLogin"): - new_content += "PermitRootLogin prohibit-password\n" - found = True - else: - new_content += line - if not found: - new_content += "\nPermitRootLogin prohibit-password\n" - if new_content != content: - SSHD_CONFIG.write_text(new_content) - systemd.service_reload("ssh", restart_on_failure=True) - - def _install_mail_sync_script(self): - """Render and install the mail pool synchronization script. - - Skipped when the secondary hostname is not yet known (no remote peer). - """ - secondary = self._secondary_hostname - if not secondary: - logger.info("Secondary hostname not yet known; skipping sync script installation") - return - - self.unit.status = MaintenanceStatus("Installing mail pool synchronization script") - template_context = { - "secondary_hostname": secondary, - "mail_root": MAIL_ROOT, - } - template = self.jinja.get_template(SYNC_TO_SECONDARY_TEMPLATE) - contents = template.render(template_context) - host.write_file(SYNC_TO_SECONDARY_TARGET, contents, perms=0o755) - - def _setup_mail_sync_cronjob(self): - """Set up the mail pool synchronization cronjob.""" - if not self._secondary_hostname: - logger.info("Secondary hostname not yet known; skipping cronjob setup") - return - - self.unit.status = MaintenanceStatus("Setting up mail pool synchronization cronjob") - template_context = { - "schedule": self.config.get("sync-schedule", "*/30 * * * *"), - } - template = self.jinja.get_template(SYNC_TO_SECONDARY_CRONJOB_TEMPLATE) - contents = template.render(template_context) - host.write_file(SYNC_TO_SECONDARY_CRONJOB_TARGET, contents, perms=0o644) - systemd.service_restart("cron") - # -- Actions -------------------------------------------------------------- def _on_clear_queue_action(self, event): diff --git a/dovecot-charm/src/constants.py b/dovecot-charm/src/constants.py index 4df5950..30b4c1b 100644 --- a/dovecot-charm/src/constants.py +++ b/dovecot-charm/src/constants.py @@ -47,3 +47,13 @@ STORAGE_DEV_PATH_FILE = "/var/lib/dovecot-charm/storage-dev-path" TLS_CERT_DIR = Path("/etc/dovecot/private") + +# HA sync paths +SYNC_TO_SECONDARY_TARGET = "/usr/local/bin/sync-to-secondary.sh" +SYNC_TO_SECONDARY_CRONJOB_TARGET = "/etc/cron.d/sync-to-secondary" +SYNC_TO_SECONDARY_TEMPLATE = "sync-to-secondary.sh.tmpl" +SYNC_TO_SECONDARY_CRONJOB_TEMPLATE = "sync-to-secondary_cron.tmpl" + +SSHD_CONFIG = Path("/etc/ssh/sshd_config") +SSH_DIR = Path("/root/.ssh") +SSH_HOST_KEY_FILE = Path("/etc/ssh/ssh_host_ed25519_key.pub") diff --git a/dovecot-charm/src/ha.py b/dovecot-charm/src/ha.py new file mode 100644 index 0000000..6d6f378 --- /dev/null +++ b/dovecot-charm/src/ha.py @@ -0,0 +1,173 @@ +# Copyright 2026 Canonical Ltd. +# See LICENSE file for licensing details. + +"""High availability functions for the Dovecot charm.""" + +from __future__ import annotations + +import logging +import socket +import subprocess # nosec +import typing + +from charmhelpers.core import host +from charmlibs import systemd +from ops.model import MaintenanceStatus + +from constants import ( + MAIL_ROOT, + PEER_RELATION_NAME, + SSH_DIR, + SSH_HOST_KEY_FILE, + SSHD_CONFIG, + SYNC_TO_SECONDARY_CRONJOB_TARGET, + SYNC_TO_SECONDARY_CRONJOB_TEMPLATE, + SYNC_TO_SECONDARY_TARGET, + SYNC_TO_SECONDARY_TEMPLATE, +) + +if typing.TYPE_CHECKING: + from charm import DovecotCharm + +logger = logging.getLogger(__name__) + + +def setup_ssh_keys(charm: DovecotCharm) -> None: + """Generate an SSH key pair if absent and publish keys via the peer relation. + + Publishes both the user public key (for authorized_keys) and the host + public key (for known_hosts) so peers can verify each other's identity + without disabling StrictHostKeyChecking. + """ + SSH_DIR.mkdir(mode=0o700, exist_ok=True) + key_file = SSH_DIR / "id_ed25519" + + if not key_file.exists(): + subprocess.run( + ["/usr/bin/ssh-keygen", "-t", "ed25519", "-N", "", "-f", str(key_file)], + check=True, + capture_output=True, + ) + + pub_key_file = SSH_DIR / "id_ed25519.pub" + if not pub_key_file.exists(): + logger.error("SSH public key file not found after key generation") + return + + pub_key = pub_key_file.read_text().strip() + relation = charm.model.get_relation(PEER_RELATION_NAME) + if relation: + relation.data[charm.unit]["public_key"] = pub_key + relation.data[charm.unit]["hostname"] = socket.gethostname() + + if SSH_HOST_KEY_FILE.exists(): + host_key = SSH_HOST_KEY_FILE.read_text().strip() + relation.data[charm.unit]["ssh_host_key"] = host_key + + +def sync_authorized_keys(charm: DovecotCharm) -> None: + """Collect public keys from all peer units and write authorized_keys.""" + relation = charm.model.get_relation(PEER_RELATION_NAME) + if not relation: + return + + authorized_keys = [] + for unit in relation.units: + pk = relation.data[unit].get("public_key") + if pk: + authorized_keys.append(pk) + + our_pk = relation.data[charm.unit].get("public_key") + if our_pk: + authorized_keys.append(our_pk) + + if not authorized_keys: + return + + auth_file = SSH_DIR / "authorized_keys" + auth_file.write_text("\n".join(authorized_keys) + "\n") + auth_file.chmod(0o600) + + ensure_root_ssh_login() + + +def sync_known_hosts(charm: DovecotCharm) -> None: + """Populate known_hosts with peer SSH host keys from the peer relation. + + Each peer publishes its host public key and hostname on the relation. + This writes those into known_hosts so SSH connections between units use + StrictHostKeyChecking (the default) instead of disabling it. + """ + relation = charm.model.get_relation(PEER_RELATION_NAME) + if not relation: + return + + entries = [] + for unit in relation.units: + host_key = relation.data[unit].get("ssh_host_key") + hostname = relation.data[unit].get("hostname") + if host_key and hostname: + entries.append(f"{hostname} {host_key}") + + if not entries: + return + + known_hosts_file = SSH_DIR / "known_hosts" + known_hosts_file.write_text("\n".join(entries) + "\n") + known_hosts_file.chmod(0o600) + + +def ensure_root_ssh_login() -> None: + """Set PermitRootLogin to prohibit-password in sshd_config and reload sshd.""" + if SSHD_CONFIG.exists(): + content = SSHD_CONFIG.read_text() + new_content = "" + found = False + for line in content.splitlines(keepends=True): + stripped = line.lstrip("#").strip() + if stripped.startswith("PermitRootLogin"): + new_content += "PermitRootLogin prohibit-password\n" + found = True + else: + new_content += line + if not found: + new_content += "\nPermitRootLogin prohibit-password\n" + if new_content != content: + SSHD_CONFIG.write_text(new_content) + systemd.service_reload("ssh", restart_on_failure=True) + + +def install_mail_sync_script(charm: DovecotCharm) -> None: + """Render and install the mail pool synchronization script. + + Skipped when the secondary hostname is not yet known (no remote peer). + """ + secondary = charm._secondary_hostname + if not secondary: + logger.info("Secondary hostname not yet known; skipping sync script installation") + return + + charm.unit.status = MaintenanceStatus("Installing mail pool synchronization script") + template_context = { + "secondary_hostname": secondary, + "mail_root": MAIL_ROOT, + } + template = charm.jinja.get_template(SYNC_TO_SECONDARY_TEMPLATE) + contents = template.render(template_context) + host.write_file(SYNC_TO_SECONDARY_TARGET, contents, perms=0o755) + + +def setup_mail_sync_cronjob(charm: DovecotCharm) -> None: + """Set up the mail pool synchronization cronjob.""" + if not charm._secondary_hostname: + logger.info("Secondary hostname not yet known; skipping cronjob setup") + return + + charm.unit.status = MaintenanceStatus("Setting up mail pool synchronization cronjob") + template_context = { + "schedule": charm.config.get("sync-schedule", "*/30 * * * *"), + } + template = charm.jinja.get_template(SYNC_TO_SECONDARY_CRONJOB_TEMPLATE) + contents = template.render(template_context) + host.write_file(SYNC_TO_SECONDARY_CRONJOB_TARGET, contents, perms=0o644) + systemd.service_restart("cron") diff --git a/dovecot-charm/tests/unit/test_charm.py b/dovecot-charm/tests/unit/test_charm.py index 2170f3a..6e84d02 100644 --- a/dovecot-charm/tests/unit/test_charm.py +++ b/dovecot-charm/tests/unit/test_charm.py @@ -31,11 +31,11 @@ def reconcile_guards(): patch("charm.DovecotCharm._setup_tls"), patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), - patch("charm.DovecotCharm._setup_ssh_keys"), - patch("charm.DovecotCharm._sync_authorized_keys"), - patch("charm.DovecotCharm._sync_known_hosts"), - patch("charm.DovecotCharm._install_mail_sync_script"), - patch("charm.DovecotCharm._setup_mail_sync_cronjob"), + patch("ha.setup_ssh_keys"), + patch("ha.sync_authorized_keys"), + patch("ha.sync_known_hosts"), + patch("ha.install_mail_sync_script"), + patch("ha.setup_mail_sync_cronjob"), ): yield @@ -147,16 +147,16 @@ def test_reconcile_skips_sync_script_when_not_primary(ctx, base_state): patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), # ssh keygen — real subprocess not under test - patch("charm.DovecotCharm._setup_ssh_keys"), + patch("ha.setup_ssh_keys"), # authorized_keys sync — not under test - patch("charm.DovecotCharm._sync_authorized_keys"), + patch("ha.sync_authorized_keys"), # known_hosts sync — not under test - patch("charm.DovecotCharm._sync_known_hosts"), + patch("ha.sync_known_hosts"), # Override _is_primary to simulate being a non-primary unit patch("charm.DovecotCharm._is_primary", new_callable=PropertyMock, return_value=False), # These should NOT be called — we verify via state not mocks - patch("charm.DovecotCharm._install_mail_sync_script") as mock_sync, - patch("charm.DovecotCharm._setup_mail_sync_cronjob") as mock_cron, + patch("ha.install_mail_sync_script") as mock_sync, + patch("ha.setup_mail_sync_cronjob") as mock_cron, ): state_out = ctx.run(ctx.on.config_changed(), base_state) diff --git a/dovecot-charm/tests/unit/test_storage.py b/dovecot-charm/tests/unit/test_storage.py index 106460f..1a9842d 100644 --- a/dovecot-charm/tests/unit/test_storage.py +++ b/dovecot-charm/tests/unit/test_storage.py @@ -30,11 +30,11 @@ def test_start_uses_saved_dev_path_when_model_error(ctx, base_state): patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), # HA methods do filesystem I/O (ssh-keygen, authorized_keys, sync scripts) - patch("charm.DovecotCharm._setup_ssh_keys"), - patch("charm.DovecotCharm._sync_authorized_keys"), - patch("charm.DovecotCharm._sync_known_hosts"), - patch("charm.DovecotCharm._install_mail_sync_script"), - patch("charm.DovecotCharm._setup_mail_sync_cronjob"), + patch("ha.setup_ssh_keys"), + patch("ha.sync_authorized_keys"), + patch("ha.sync_known_hosts"), + patch("ha.install_mail_sync_script"), + patch("ha.setup_mail_sync_cronjob"), patch("ops._main._Dispatcher.run_any_legacy_hook"), ): state_out = ctx.run(ctx.on.start(), state_in) @@ -79,11 +79,11 @@ def test_storage_attached_luks_auto_provisioning_disabled_mounted_is_active(ctx, patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), # HA methods do filesystem I/O — not under test - patch("charm.DovecotCharm._setup_ssh_keys"), - patch("charm.DovecotCharm._sync_authorized_keys"), - patch("charm.DovecotCharm._sync_known_hosts"), - patch("charm.DovecotCharm._install_mail_sync_script"), - patch("charm.DovecotCharm._setup_mail_sync_cronjob"), + patch("ha.setup_ssh_keys"), + patch("ha.sync_authorized_keys"), + patch("ha.sync_known_hosts"), + patch("ha.install_mail_sync_script"), + patch("ha.setup_mail_sync_cronjob"), ): state_out = ctx.run(ctx.on.storage_attached(storage), state_in) assert isinstance(state_out.unit_status, ops.ActiveStatus) @@ -118,11 +118,11 @@ def test_storage_attached_calls_setup_luks_with_key(ctx, base_state): patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), # HA methods do filesystem I/O — not under test - patch("charm.DovecotCharm._setup_ssh_keys"), - patch("charm.DovecotCharm._sync_authorized_keys"), - patch("charm.DovecotCharm._sync_known_hosts"), - patch("charm.DovecotCharm._install_mail_sync_script"), - patch("charm.DovecotCharm._setup_mail_sync_cronjob"), + patch("ha.setup_ssh_keys"), + patch("ha.sync_authorized_keys"), + patch("ha.sync_known_hosts"), + patch("ha.install_mail_sync_script"), + patch("ha.setup_mail_sync_cronjob"), ): state_out = ctx.run(ctx.on.storage_attached(storage), state_in) assert isinstance(state_out.unit_status, ops.ActiveStatus) @@ -144,11 +144,11 @@ def test_storage_attached_saves_dev_path(ctx, base_state): patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), # HA methods do filesystem I/O — not under test - patch("charm.DovecotCharm._setup_ssh_keys"), - patch("charm.DovecotCharm._sync_authorized_keys"), - patch("charm.DovecotCharm._sync_known_hosts"), - patch("charm.DovecotCharm._install_mail_sync_script"), - patch("charm.DovecotCharm._setup_mail_sync_cronjob"), + patch("ha.setup_ssh_keys"), + patch("ha.sync_authorized_keys"), + patch("ha.sync_known_hosts"), + patch("ha.install_mail_sync_script"), + patch("ha.setup_mail_sync_cronjob"), ): state_out = ctx.run(ctx.on.storage_attached(storage), state_in) assert isinstance(state_out.unit_status, ops.ActiveStatus) @@ -209,11 +209,11 @@ def test_storage_detaching_unmount_and_close(ctx, base_state): patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), # HA methods do filesystem I/O — not under test - patch("charm.DovecotCharm._setup_ssh_keys"), - patch("charm.DovecotCharm._sync_authorized_keys"), - patch("charm.DovecotCharm._sync_known_hosts"), - patch("charm.DovecotCharm._install_mail_sync_script"), - patch("charm.DovecotCharm._setup_mail_sync_cronjob"), + patch("ha.setup_ssh_keys"), + patch("ha.sync_authorized_keys"), + patch("ha.sync_known_hosts"), + patch("ha.install_mail_sync_script"), + patch("ha.setup_mail_sync_cronjob"), ): state_out = ctx.run(ctx.on.storage_detaching(storage), state_in) assert isinstance(state_out.unit_status, ops.ActiveStatus) @@ -250,11 +250,11 @@ def test_storage_detaching_luks_disabled_skips_close(ctx, base_state): patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), # HA methods do filesystem I/O — not under test - patch("charm.DovecotCharm._setup_ssh_keys"), - patch("charm.DovecotCharm._sync_authorized_keys"), - patch("charm.DovecotCharm._sync_known_hosts"), - patch("charm.DovecotCharm._install_mail_sync_script"), - patch("charm.DovecotCharm._setup_mail_sync_cronjob"), + patch("ha.setup_ssh_keys"), + patch("ha.sync_authorized_keys"), + patch("ha.sync_known_hosts"), + patch("ha.install_mail_sync_script"), + patch("ha.setup_mail_sync_cronjob"), ): state_out = ctx.run(ctx.on.storage_detaching(storage), state_in) assert isinstance(state_out.unit_status, ops.ActiveStatus) diff --git a/dovecot-charm/tests/unit/test_tls.py b/dovecot-charm/tests/unit/test_tls.py index 86d95aa..6a50d0f 100644 --- a/dovecot-charm/tests/unit/test_tls.py +++ b/dovecot-charm/tests/unit/test_tls.py @@ -47,7 +47,7 @@ def test_setup_tls_writes_cert_key_and_chain(ctx, base_state, tmp_path): patch("charm.shutil.which", return_value="/usr/bin/doveconf"), patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), - patch("charm.DovecotCharm._setup_ssh_keys"), + patch("ha.setup_ssh_keys"), ctx(ctx.on.config_changed(), base_state) as mgr, ): # Override the TLS library instance so get_assigned_certificate @@ -86,7 +86,7 @@ def test_setup_tls_no_ca_omits_chain(ctx, base_state, tmp_path): patch("charm.shutil.which", return_value="/usr/bin/doveconf"), patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), - patch("charm.DovecotCharm._setup_ssh_keys"), + patch("ha.setup_ssh_keys"), ctx(ctx.on.config_changed(), base_state) as mgr, ): mgr.charm._tls = MagicMock() @@ -143,7 +143,7 @@ def test_certificate_available_event_triggers_reconcile(ctx, base_state, tmp_pat ), patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), - patch("charm.DovecotCharm._setup_ssh_keys"), + patch("ha.setup_ssh_keys"), ): # Fire certificate_available via config_changed (same handler) state_out = ctx.run(ctx.on.config_changed(), base_state) From 700cd907e9f3c0bf3d0225dd2f7556d368445e44 Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Tue, 21 Apr 2026 15:50:38 +0300 Subject: [PATCH 31/39] refactor(tests): increase timeout for unit activation in HA and storage tests --- docs/release-notes/artifacts/pr-4-ha.yaml | 2 +- docs/release-notes/release-notes-0005.rst | 2 +- dovecot-charm/src/charm.py | 31 +++++--- dovecot-charm/src/exceptions.py | 6 ++ dovecot-charm/src/ha.py | 75 ++++++++++++------- .../templates/sync-to-secondary.sh.tmpl | 2 +- dovecot-charm/tests/integration/test_ha.py | 9 ++- dovecot-charm/tests/integration/test_mail.py | 2 +- .../tests/integration/test_storage.py | 6 +- dovecot-charm/tests/unit/test_charm.py | 66 ++++++++++------ 10 files changed, 134 insertions(+), 67 deletions(-) diff --git a/docs/release-notes/artifacts/pr-4-ha.yaml b/docs/release-notes/artifacts/pr-4-ha.yaml index 0b7b284..c20101b 100644 --- a/docs/release-notes/artifacts/pr-4-ha.yaml +++ b/docs/release-notes/artifacts/pr-4-ha.yaml @@ -8,7 +8,7 @@ changes: - title: Added HA support with SSH key exchange and force-sync action author: alithethird type: major - description: Added high availability support for Dovecot with automatic SSH key exchange between primary and secondary units via the replicas peer relation, rsync-based mail synchronization via cron, and a force-sync Juju action for on-demand replication. + description: Added high availability support for Dovecot with automatic SSH key exchange between primary and secondary units via the replicas peer relation, doveadm backup (dsync)-based mail synchronization via cron, and a force-sync Juju action for on-demand replication. urls: pr: - "https://github.com/canonical/mailserver-operators/pull/15" diff --git a/docs/release-notes/release-notes-0005.rst b/docs/release-notes/release-notes-0005.rst index 9a5ac16..a09e423 100644 --- a/docs/release-notes/release-notes-0005.rst +++ b/docs/release-notes/release-notes-0005.rst @@ -44,7 +44,7 @@ secondary unit on demand. Relevant links: -* `PR `_ +* `PR `_ Bug fixes --------- diff --git a/dovecot-charm/src/charm.py b/dovecot-charm/src/charm.py index d67eda8..ba54b98 100644 --- a/dovecot-charm/src/charm.py +++ b/dovecot-charm/src/charm.py @@ -39,7 +39,7 @@ TLS_CERT_DIR, ) from dovecot_config import DovecotConfig, DovecotConfigInvalidError, DovecotConfigSecretError -from exceptions import CharmBlockedError, ConfigurationError +from exceptions import CharmBlockedError, ConfigurationError, HASetupError from storage import ensure_storage_ready, teardown_detaching_storage logger = logging.getLogger(__name__) @@ -58,7 +58,7 @@ def __init__(self, *args): self.framework.observe(self.on.clear_queue_action, self._on_clear_queue_action) self.framework.observe(self.on.mail_data_storage_attached, self._reconcile) self.framework.observe(self.on.mail_data_storage_detaching, self._reconcile) - self.framework.observe(self.on.replicas_relation_changed, self._reconcile) + self.framework.observe(self.on[PEER_RELATION_NAME].relation_changed, self._reconcile) self.framework.observe(self.on.force_sync_action, self._on_force_sync) self.framework.observe( @@ -173,12 +173,16 @@ def _reconcile(self, event): except ConfigurationError as e: self.unit.status = BlockedStatus(str(e)) return - ha.setup_ssh_keys(self) - ha.sync_authorized_keys(self) - ha.sync_known_hosts(self) - if self._is_primary: - ha.install_mail_sync_script(self) - ha.setup_mail_sync_cronjob(self) + try: + ha.setup_ssh_keys(self) + ha.sync_authorized_keys(self) + ha.sync_known_hosts(self) + if self._is_primary: + ha.install_mail_sync_script(self) + ha.setup_mail_sync_cronjob(self) + except HASetupError as e: + self.unit.status = BlockedStatus(str(e)) + return self._open_ports() self.unit.status = ops.ActiveStatus() @@ -306,13 +310,20 @@ def _on_force_sync(self, event): event.fail("No secondary unit found to sync to.") return + if not Path(SYNC_TO_SECONDARY_TARGET).exists(): + event.fail( + "Sync script not yet installed. " + "Please wait for the charm to reach active state before running force-sync." + ) + return + try: cmd = [SYNC_TO_SECONDARY_TARGET] logger.info(f"Running manual sync: {' '.join(cmd)}") subprocess.run(cmd, check=True, capture_output=True, text=True) event.set_results({"result": "Sync completed successfully"}) - except subprocess.CalledProcessError as e: - msg = f"Sync failed: {e.stderr}" + except (subprocess.CalledProcessError, FileNotFoundError) as e: + msg = f"Sync failed: {e}" logger.error(msg) event.fail(msg) diff --git a/dovecot-charm/src/exceptions.py b/dovecot-charm/src/exceptions.py index f3e9f1c..a2baaf3 100644 --- a/dovecot-charm/src/exceptions.py +++ b/dovecot-charm/src/exceptions.py @@ -29,3 +29,9 @@ class ConfigurationError(CharmBlockedError): """Raised when charm or service configuration is invalid or fails.""" pass + + +class HASetupError(CharmBlockedError): + """Raised when HA setup (SSH keys, sync scripts, sshd config, etc.) fails.""" + + pass diff --git a/dovecot-charm/src/ha.py b/dovecot-charm/src/ha.py index 6d6f378..79fd333 100644 --- a/dovecot-charm/src/ha.py +++ b/dovecot-charm/src/ha.py @@ -9,6 +9,7 @@ import socket import subprocess # nosec import typing +from pathlib import Path from charmhelpers.core import host from charmlibs import systemd @@ -25,6 +26,7 @@ SYNC_TO_SECONDARY_TARGET, SYNC_TO_SECONDARY_TEMPLATE, ) +from exceptions import HASetupError if typing.TYPE_CHECKING: from charm import DovecotCharm @@ -38,21 +40,26 @@ def setup_ssh_keys(charm: DovecotCharm) -> None: Publishes both the user public key (for authorized_keys) and the host public key (for known_hosts) so peers can verify each other's identity without disabling StrictHostKeyChecking. + + Raises: + HASetupError: If SSH key generation fails. """ SSH_DIR.mkdir(mode=0o700, exist_ok=True) key_file = SSH_DIR / "id_ed25519" if not key_file.exists(): - subprocess.run( - ["/usr/bin/ssh-keygen", "-t", "ed25519", "-N", "", "-f", str(key_file)], - check=True, - capture_output=True, - ) + try: + subprocess.run( + ["/usr/bin/ssh-keygen", "-t", "ed25519", "-N", "", "-f", str(key_file)], + check=True, + capture_output=True, + ) + except subprocess.CalledProcessError as e: + raise HASetupError(f"SSH key generation failed: {e.stderr}") from e pub_key_file = SSH_DIR / "id_ed25519.pub" if not pub_key_file.exists(): - logger.error("SSH public key file not found after key generation") - return + raise HASetupError("SSH public key file not found after key generation") pub_key = pub_key_file.read_text().strip() relation = charm.model.get_relation(PEER_RELATION_NAME) @@ -118,23 +125,32 @@ def sync_known_hosts(charm: DovecotCharm) -> None: def ensure_root_ssh_login() -> None: - """Set PermitRootLogin to prohibit-password in sshd_config and reload sshd.""" - if SSHD_CONFIG.exists(): - content = SSHD_CONFIG.read_text() - new_content = "" - found = False - for line in content.splitlines(keepends=True): - stripped = line.lstrip("#").strip() - if stripped.startswith("PermitRootLogin"): - new_content += "PermitRootLogin prohibit-password\n" - found = True - else: - new_content += line - if not found: - new_content += "\nPermitRootLogin prohibit-password\n" - if new_content != content: - SSHD_CONFIG.write_text(new_content) + """Set PermitRootLogin to prohibit-password in sshd_config and reload sshd. + + Raises: + HASetupError: If sshd reload fails. + """ + if not SSHD_CONFIG.exists(): + return + + content = SSHD_CONFIG.read_text() + new_content = "" + found = False + for line in content.splitlines(keepends=True): + stripped = line.lstrip("#").strip() + if stripped.startswith("PermitRootLogin"): + new_content += "PermitRootLogin prohibit-password\n" + found = True + else: + new_content += line + if not found: + new_content += "\nPermitRootLogin prohibit-password\n" + if new_content != content: + SSHD_CONFIG.write_text(new_content) + try: systemd.service_reload("ssh", restart_on_failure=True) + except subprocess.CalledProcessError as e: + raise HASetupError(f"Failed to reload sshd after config change: {e}") from e def install_mail_sync_script(charm: DovecotCharm) -> None: @@ -158,7 +174,12 @@ def install_mail_sync_script(charm: DovecotCharm) -> None: def setup_mail_sync_cronjob(charm: DovecotCharm) -> None: - """Set up the mail pool synchronization cronjob.""" + """Set up the mail pool synchronization cronjob. + + Skips writing and does not restart cron if the file content is unchanged. + Cron on Ubuntu automatically picks up changes in /etc/cron.d, so no + service restart is needed when the file is updated. + """ if not charm._secondary_hostname: logger.info("Secondary hostname not yet known; skipping cronjob setup") return @@ -169,5 +190,9 @@ def setup_mail_sync_cronjob(charm: DovecotCharm) -> None: } template = charm.jinja.get_template(SYNC_TO_SECONDARY_CRONJOB_TEMPLATE) contents = template.render(template_context) + + cronjob_path = Path(SYNC_TO_SECONDARY_CRONJOB_TARGET) + if cronjob_path.exists() and cronjob_path.read_text() == contents: + return + host.write_file(SYNC_TO_SECONDARY_CRONJOB_TARGET, contents, perms=0o644) - systemd.service_restart("cron") diff --git a/dovecot-charm/templates/sync-to-secondary.sh.tmpl b/dovecot-charm/templates/sync-to-secondary.sh.tmpl index 126ad02..41e0914 100644 --- a/dovecot-charm/templates/sync-to-secondary.sh.tmpl +++ b/dovecot-charm/templates/sync-to-secondary.sh.tmpl @@ -15,6 +15,6 @@ for user_dir in "{{ mail_root }}"/*; do done if [ "$found" -eq 0 ]; then echo "No Maildir found under {{ mail_root }}; nothing to sync." >&2 - exit 1 + exit 0 fi touch {{ mail_root }}/.last-dsync diff --git a/dovecot-charm/tests/integration/test_ha.py b/dovecot-charm/tests/integration/test_ha.py index 5211df9..6b92432 100644 --- a/dovecot-charm/tests/integration/test_ha.py +++ b/dovecot-charm/tests/integration/test_ha.py @@ -14,8 +14,9 @@ def _get_unit_hostname(status, app_name, unit_name): machine = status.apps[app_name].units[unit_name].machine return status.machines[machine].hostname except KeyError: - logging.error(f"Unit {unit_name} not found in status.") - return None + message = f"Could not determine hostname for unit {unit_name} from Juju status." + logging.error(message) + pytest.fail(message) @pytest.mark.timeout(1800) @@ -34,7 +35,7 @@ def two_units_active(status): return jubilant.all_active(status) logging.info("Waiting for 2 units to be active...") - juju.wait(two_units_active, timeout=600) + juju.wait(two_units_active, timeout=10 * 60) status = juju.status() units = list(status.apps[dovecot_charm].units.keys()) @@ -46,7 +47,7 @@ def two_units_active(status): logging.info(f"Primary: {primary}, Secondary: {secondary}") juju.config(dovecot_charm, {"primary-unit": primary}) - juju.wait(jubilant.all_active, timeout=300) + juju.wait(jubilant.all_active, timeout=5 * 60) logging.info("Verifying SSH key exchange...") diff --git a/dovecot-charm/tests/integration/test_mail.py b/dovecot-charm/tests/integration/test_mail.py index 3d063ff..315772e 100644 --- a/dovecot-charm/tests/integration/test_mail.py +++ b/dovecot-charm/tests/integration/test_mail.py @@ -17,7 +17,7 @@ def test_mail_workflow(juju: jubilant.Juju, dovecot_charm: str): unit_name = f"{dovecot_charm}/0" logging.info(f"Updating primary-unit config to {unit_name}...") juju.config(dovecot_charm, {"primary-unit": unit_name}) - juju.wait(jubilant.all_active, timeout=300) + juju.wait(jubilant.all_active, timeout=5 * 60) password = token_hex(8) logging.info("Configuring user 'ubuntu'...") diff --git a/dovecot-charm/tests/integration/test_storage.py b/dovecot-charm/tests/integration/test_storage.py index 9203398..de51ee8 100644 --- a/dovecot-charm/tests/integration/test_storage.py +++ b/dovecot-charm/tests/integration/test_storage.py @@ -17,7 +17,7 @@ def test_luks_storage_auto_provisioning(juju: jubilant.Juju, dovecot_charm: str) logging.info(f"Targeting unit: {unit_name}") logging.info("Waiting for charm to be active with storage attached...") - juju.wait(jubilant.all_active, timeout=600) + juju.wait(jubilant.all_active, timeout=10 * 60) logging.info("Verifying LUKS setup...") juju.exec("ls -l /dev/mapper/mail-data", unit=unit_name) @@ -126,7 +126,7 @@ def test_luks_storage_manual_provisioning(juju: jubilant.Juju, dovecot_charm_man juju.config(dovecot_charm_manual_storage, {"mailname": "example1.com"}) logging.info("Waiting for charm to become active...") - juju.wait(jubilant.all_active, timeout=300) + juju.wait(jubilant.all_active, timeout=5 * 60) # Verify LUKS device status logging.info("Verifying LUKS device is properly configured...") @@ -182,7 +182,7 @@ def test_data_persists_across_restart(juju: jubilant.Juju, dovecot_charm: str): # Wait for charm to re-settle after reboot logging.info("Waiting for charm to re-settle...") - juju.wait(jubilant.all_active, timeout=600) + juju.wait(jubilant.all_active, timeout=10 * 60) # After reboot the Juju storage API may not yet be re-provisioned when the # start hook fires; the charm defers and retries until LUKS open + mount diff --git a/dovecot-charm/tests/unit/test_charm.py b/dovecot-charm/tests/unit/test_charm.py index 6e84d02..ed9393e 100644 --- a/dovecot-charm/tests/unit/test_charm.py +++ b/dovecot-charm/tests/unit/test_charm.py @@ -10,7 +10,7 @@ import pytest from charm import DovecotCharm -from exceptions import ConfigurationError +from exceptions import ConfigurationError, HASetupError # --------------------------------------------------------------------------- # Helpers — patches shared across many tests @@ -85,7 +85,6 @@ def test_reconcile_blocks_when_procmail_setup_fails(ctx, base_state): patch("charm.shutil.which", return_value="/usr/bin/doveconf"), patch("charm.DovecotCharm._setup_tls"), patch("charm.DovecotCharm._setup_dovecot"), - # _setup_procmail raises — this is the condition under test patch( "charm.DovecotCharm._setup_procmail", side_effect=ConfigurationError("Failed to configure postfix: error"), @@ -104,7 +103,6 @@ def test_reconcile_blocks_when_procmail_setup_fails(ctx, base_state): def test_is_primary_true_when_unit_matches_config(ctx, base_state): """_is_primary returns True when primary-unit config matches this unit.""" - # base_state has primary-unit=dovecot-charm/0; the ctx app_name gives unit dovecot-charm/0 with reconcile_guards(), ctx(ctx.on.config_changed(), base_state) as mgr: assert mgr.charm._is_primary is True @@ -120,7 +118,6 @@ def test_is_primary_false_when_unit_differs(ctx, base_state): base_state, config={**base_state.config, "primary-unit": "dovecot-charm/99"} ) with ( - # config validation rejects unknown units — bypass it since we're only testing _is_primary patch("charm.DovecotCharm._get_dovecot_config"), patch("charm.ensure_storage_ready"), patch("charm.teardown_detaching_storage"), @@ -137,8 +134,6 @@ def test_is_primary_false_when_unit_differs(ctx, base_state): def test_reconcile_skips_sync_script_when_not_primary(ctx, base_state): """When this unit is NOT primary, sync script and cronjob are not installed.""" - # Use a valid config but override _is_primary to False to bypass pydantic - # validation (which requires primary-unit to match an existing unit). with ( patch("charm.ensure_storage_ready"), patch("charm.teardown_detaching_storage"), @@ -146,23 +141,16 @@ def test_reconcile_skips_sync_script_when_not_primary(ctx, base_state): patch("charm.DovecotCharm._setup_tls"), patch("charm.DovecotCharm._setup_dovecot"), patch("charm.DovecotCharm._setup_procmail"), - # ssh keygen — real subprocess not under test patch("ha.setup_ssh_keys"), - # authorized_keys sync — not under test patch("ha.sync_authorized_keys"), - # known_hosts sync — not under test patch("ha.sync_known_hosts"), - # Override _is_primary to simulate being a non-primary unit patch("charm.DovecotCharm._is_primary", new_callable=PropertyMock, return_value=False), - # These should NOT be called — we verify via state not mocks patch("ha.install_mail_sync_script") as mock_sync, patch("ha.setup_mail_sync_cronjob") as mock_cron, ): state_out = ctx.run(ctx.on.config_changed(), base_state) - # Charm still reaches Active even without sync scripts assert isinstance(state_out.unit_status, ops.ActiveStatus) - # Secondary check: these should not have been called since unit is not primary mock_sync.assert_not_called() mock_cron.assert_not_called() @@ -176,7 +164,6 @@ def test_clear_queue_deferred(ctx, base_state): """clear-queue action with queue=deferred passes correct args to postsuper.""" mock_result = MagicMock(stdout="cleared") with ( - # postsuper is the only subprocess call in this action path patch("charm.subprocess.run", return_value=mock_result) as mock_run, ): ctx.run( @@ -196,7 +183,6 @@ def test_clear_queue_all(ctx, base_state): """clear-queue action with queue=all omits the deferred queue filter.""" mock_result = MagicMock(stdout="cleared") with ( - # postsuper is the only subprocess call in this action path patch("charm.subprocess.run", return_value=mock_result) as mock_run, ): ctx.run( @@ -215,7 +201,6 @@ def test_clear_queue_all(ctx, base_state): def test_clear_queue_failure(ctx, base_state): """clear-queue action must fail when postsuper returns non-zero.""" with ( - # simulate postsuper failure patch( "charm.subprocess.run", side_effect=CalledProcessError(1, "postsuper", stderr="error msg"), @@ -238,9 +223,8 @@ def test_force_sync_success(ctx, base_state): """force-sync succeeds when this unit is primary and a secondary exists.""" mock_result = MagicMock(stdout="ok", stderr="") with ( - # sync script subprocess call — the action delegates to the shell script patch("charm.subprocess.run", return_value=mock_result), - # provide a secondary hostname so the action doesn't bail out + patch("charm.Path") as mock_path_cls, patch.object( DovecotCharm, "_secondary_hostname", @@ -248,13 +232,13 @@ def test_force_sync_success(ctx, base_state): return_value="10.0.0.2", ), ): + mock_path_cls.return_value.exists.return_value = True ctx.run(ctx.on.action("force-sync"), base_state) assert ctx.action_results == {"result": "Sync completed successfully"} def test_force_sync_not_primary(ctx, base_state): """force-sync must fail when executed on a non-primary unit.""" - # Override _is_primary since pydantic rejects unknown unit names with ( patch("charm.DovecotCharm._is_primary", new_callable=PropertyMock, return_value=False), pytest.raises(ops.testing.ActionFailed) as exc_info, @@ -273,12 +257,11 @@ def test_force_sync_no_secondary(ctx, base_state): def test_force_sync_subprocess_failure(ctx, base_state): """force-sync must fail when the sync script exits non-zero.""" with ( - # sync script fails patch( "charm.subprocess.run", side_effect=CalledProcessError(1, "sync", stderr="fail"), ), - # provide secondary so the action reaches subprocess.run + patch("charm.Path") as mock_path_cls, patch.object( DovecotCharm, "_secondary_hostname", @@ -287,5 +270,46 @@ def test_force_sync_subprocess_failure(ctx, base_state): ), pytest.raises(ops.testing.ActionFailed) as exc_info, ): + mock_path_cls.return_value.exists.return_value = True ctx.run(ctx.on.action("force-sync"), base_state) assert "fail" in exc_info.value.message + + +# --------------------------------------------------------------------------- +# HA: reconcile blocks on HASetupError +# --------------------------------------------------------------------------- + + +def test_reconcile_blocks_when_ha_setup_fails(ctx, base_state): + """Charm must be Blocked when HA setup raises HASetupError.""" + with ( + patch("charm.ensure_storage_ready"), + patch("charm.teardown_detaching_storage"), + patch("charm.shutil.which", return_value="/usr/bin/doveconf"), + patch("charm.DovecotCharm._setup_tls"), + patch("charm.DovecotCharm._setup_dovecot"), + patch("charm.DovecotCharm._setup_procmail"), + patch("ha.setup_ssh_keys", side_effect=HASetupError("SSH keygen failed")), + ): + state_out = ctx.run(ctx.on.config_changed(), base_state) + + assert isinstance(state_out.unit_status, ops.BlockedStatus) + assert "SSH keygen failed" in state_out.unit_status.message + + +def test_force_sync_script_not_installed(ctx, base_state): + """force-sync must fail with a clear message when sync script is not yet installed.""" + with ( + patch.object( + DovecotCharm, + "_secondary_hostname", + new_callable=PropertyMock, + return_value="10.0.0.2", + ), + patch("charm.Path") as mock_path_cls, + pytest.raises(ops.testing.ActionFailed) as exc_info, + ): + mock_path_cls.return_value.exists.return_value = False + ctx.run(ctx.on.action("force-sync"), base_state) + + assert "wait for the charm" in exc_info.value.message From 28beadecb86354ad975bc2ccbb6ff7299526217f Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Wed, 22 Apr 2026 09:52:01 +0300 Subject: [PATCH 32/39] refactor(ha): enhance SSH key generation error handling and update cron template logging --- dovecot-charm/src/charm.py | 16 +- dovecot-charm/src/ha.py | 1 + .../templates/sync-to-secondary_cron.tmpl | 4 +- dovecot-charm/tests/integration/test_ha.py | 264 +++++++++++++----- 4 files changed, 214 insertions(+), 71 deletions(-) diff --git a/dovecot-charm/src/charm.py b/dovecot-charm/src/charm.py index ba54b98..d5133ca 100644 --- a/dovecot-charm/src/charm.py +++ b/dovecot-charm/src/charm.py @@ -275,8 +275,6 @@ def _setup_procmail(self) -> None: logger.exception(f"Failed to configure postfix: {e}") raise ConfigurationError(f"Failed to configure postfix: {e.stderr}") from e - # -- Actions -------------------------------------------------------------- - def _on_clear_queue_action(self, event): """Handle the clear-queue action.""" queue_to_clear = event.params.get("queue", "deferred") @@ -322,7 +320,19 @@ def _on_force_sync(self, event): logger.info(f"Running manual sync: {' '.join(cmd)}") subprocess.run(cmd, check=True, capture_output=True, text=True) event.set_results({"result": "Sync completed successfully"}) - except (subprocess.CalledProcessError, FileNotFoundError) as e: + except subprocess.CalledProcessError as e: + parts = [ + f"Sync failed with exit code {e.returncode} while running " + f"{' '.join(e.cmd) if isinstance(e.cmd, (list, tuple)) else e.cmd}" + ] + if e.stderr and e.stderr.strip(): + parts.append(f"stderr: {e.stderr.strip()}") + if e.stdout and e.stdout.strip(): + parts.append(f"stdout: {e.stdout.strip()}") + msg = ". ".join(parts) + logger.error(msg) + event.fail(msg) + except FileNotFoundError as e: msg = f"Sync failed: {e}" logger.error(msg) event.fail(msg) diff --git a/dovecot-charm/src/ha.py b/dovecot-charm/src/ha.py index 79fd333..fb0abae 100644 --- a/dovecot-charm/src/ha.py +++ b/dovecot-charm/src/ha.py @@ -53,6 +53,7 @@ def setup_ssh_keys(charm: DovecotCharm) -> None: ["/usr/bin/ssh-keygen", "-t", "ed25519", "-N", "", "-f", str(key_file)], check=True, capture_output=True, + text=True, ) except subprocess.CalledProcessError as e: raise HASetupError(f"SSH key generation failed: {e.stderr}") from e diff --git a/dovecot-charm/templates/sync-to-secondary_cron.tmpl b/dovecot-charm/templates/sync-to-secondary_cron.tmpl index d3101cb..cd2e1a2 100644 --- a/dovecot-charm/templates/sync-to-secondary_cron.tmpl +++ b/dovecot-charm/templates/sync-to-secondary_cron.tmpl @@ -1,3 +1 @@ -{{ schedule }} root /usr/local/bin/sync-to-secondary.sh >> /var/log/sync-to-secondary.log 2>&1 - -# End of file +{{ schedule }} root /usr/local/bin/sync-to-secondary.sh 2>&1 | /usr/bin/logger -t sync-to-secondary diff --git a/dovecot-charm/tests/integration/test_ha.py b/dovecot-charm/tests/integration/test_ha.py index 6b92432..b0fcd60 100644 --- a/dovecot-charm/tests/integration/test_ha.py +++ b/dovecot-charm/tests/integration/test_ha.py @@ -1,7 +1,12 @@ # Copyright 2026 Canonical Ltd. # See LICENSE file for licensing details. +import contextlib +import imaplib import logging +import ssl +import time +from secrets import token_hex from typing import cast import jubilant @@ -19,8 +24,112 @@ def _get_unit_hostname(status, app_name, unit_name): pytest.fail(message) -@pytest.mark.timeout(1800) -def test_ha_failover(juju, dovecot_charm): +def _check_mail_via_imap(unit_ip: str, user: str, password: str, subject: str) -> bool: + """Poll IMAP on unit_ip until the email with the given subject is found.""" + context = ssl.create_default_context() + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + + for attempt in range(20): + mail = None + try: + mail = imaplib.IMAP4_SSL(unit_ip, port=993, ssl_context=context) + mail.login(user, password) + mail.select("inbox") + _, data = mail.search(None, f'(HEADER Subject "{subject}")') + if data and data[0]: + logging.info(f"Email found via IMAP on {unit_ip}. IDs: {data[0]}") + return True + logging.info(f"Email not found yet on {unit_ip} (attempt {attempt + 1})...") + except (imaplib.IMAP4.error, OSError) as e: + logging.warning(f"IMAP attempt {attempt + 1} on {unit_ip} failed: {e}. Retrying...") + finally: + if mail is not None: + with contextlib.suppress(imaplib.IMAP4.error, OSError): + mail.close() + with contextlib.suppress(imaplib.IMAP4.error, OSError): + mail.logout() + time.sleep(3) + + return False + + +def _setup_mail_user(juju: jubilant.Juju, units: list[str], user: str, password: str): + """Create a system user with a Maildir on each unit, set password on all units.""" + for unit in units: + juju.exec( + ( + f"id -u {user} >/dev/null 2>&1 || " + f"useradd -M -d /srv/mail/{user} -s /usr/sbin/nologin {user}" + ), + unit=unit, + ) + juju.exec(f"echo '{user}:{password}' | chpasswd", unit=unit) + juju.exec(f"usermod -aG mail {user}", unit=unit) + + # Maildir only needs to exist on primary so doveadm backup has something to sync + primary = units[0] + juju.exec( + ( + f"mkdir -p /srv/mail/{user}/Maildir/{{new,cur,tmp}} && " + f"chown -R {user}:{user} /srv/mail/{user} && " + f"chmod 700 /srv/mail/{user} /srv/mail/{user}/Maildir" + ), + unit=primary, + ) + + +def _get_last_sync_mtime(juju: jubilant.Juju, unit: str) -> int | None: + """Return /srv/mail/.last-dsync mtime epoch on unit, or None if missing.""" + output = juju.exec( + "stat -c %Y /srv/mail/.last-dsync 2>/dev/null || true", unit=unit + ).stdout.strip() + return int(output) if output.isdigit() else None + + +def _get_sync_cron_run_count(juju: jubilant.Juju, unit: str) -> int: + """Return count of sync-to-secondary cron executions recorded in syslog.""" + output = juju.exec( + "grep -c 'sync-to-secondary.sh' /var/log/syslog 2>/dev/null || true", + unit=unit, + ).stdout.strip() + return int(output) if output.isdigit() else 0 + + +def _wait_for_sync_trigger( + juju: jubilant.Juju, + unit: str, + previous_mtime: int | None, + previous_cron_count: int, + timeout: int = 4 * 60, + poll_interval: int = 5, +) -> int: + """Wait for a cron-triggered sync signal and return observed marker mtime. + + Accepts either /srv/mail/.last-dsync mtime advance or syslog cron count increase + to work across charm revisions with different script logging/exit behavior. + """ + deadline = time.time() + timeout + while time.time() < deadline: + current_mtime = _get_last_sync_mtime(juju, unit) + if current_mtime is not None and (previous_mtime is None or current_mtime > previous_mtime): + return current_mtime + + current_cron_count = _get_sync_cron_run_count(juju, unit) + if current_cron_count > previous_cron_count: + return current_mtime or 0 + + time.sleep(poll_interval) + + raise AssertionError( + "Timed out waiting for sync trigger on " + f"{unit}; previous mtime={previous_mtime}, previous cron count={previous_cron_count}" + ) + + +@pytest.mark.timeout(30 * 60) +def test_force_sync_action(juju: jubilant.Juju, dovecot_charm: str): + """force-sync action replicates mail from primary to secondary via doveadm backup.""" status = juju.status() if len(status.apps[dovecot_charm].units) < 2: logging.info("Adding the second unit...") @@ -28,9 +137,7 @@ def test_ha_failover(juju, dovecot_charm): def two_units_active(status): app = status.apps.get(dovecot_charm) - if not app: - return False - if len(app.units) < 2: + if not app or len(app.units) < 2: return False return jubilant.all_active(status) @@ -38,82 +145,109 @@ def two_units_active(status): juju.wait(two_units_active, timeout=10 * 60) status = juju.status() - units = list(status.apps[dovecot_charm].units.keys()) - units.sort(key=lambda x: int(x.split("/")[-1])) - - primary = units[0] - secondary = units[1] - + units = sorted(status.apps[dovecot_charm].units.keys(), key=lambda x: int(x.split("/")[-1])) + primary, secondary = units[0], units[1] logging.info(f"Primary: {primary}, Secondary: {secondary}") juju.config(dovecot_charm, {"primary-unit": primary}) juju.wait(jubilant.all_active, timeout=5 * 60) - logging.info("Verifying SSH key exchange...") + # Remove legacy HA test users that can break dsync on reruns. + for unit in (primary, secondary): + juju.exec("rm -rf /srv/mail/syncuser* /srv/mail/autosyncuser*", unit=unit) - cmd = "cat /root/.ssh/authorized_keys | wc -l" + # Set up test user on both units (PAM auth requires user to exist on secondary for IMAP) + user = f"syncuser{token_hex(3)}" + password = token_hex(8) + for unit in (primary, secondary): + juju.exec(f"rm -rf /srv/mail/{user}", unit=unit) + _setup_mail_user(juju, [primary, secondary], user, password) - result_primary = juju.exec(cmd, unit=primary) - logging.info(f"Primary authorized_keys count: {result_primary.stdout.strip()}") - assert int(result_primary.stdout.strip()) >= 1 + # Send email on primary + subject = f"Force Sync Test {token_hex(4)}" + logging.info(f"Sending test email on primary with subject: {subject}") + juju.exec(f"echo 'test body' | mail -s '{subject}' {user}@localhost", unit=primary) - result_secondary = juju.exec(cmd, unit=secondary) - logging.info(f"Secondary authorized_keys count: {result_secondary.stdout.strip()}") - assert int(result_secondary.stdout.strip()) >= 1 + # Run force-sync on primary + logging.info("Running force-sync action on primary...") + task = juju.run(unit=primary, action="force-sync", wait=2 * 60) + assert task.status == "completed" + assert task.results["result"] == "Sync completed successfully" - logging.info("Verifying sync script on Primary...") + # Verify email arrived on secondary via IMAP + secondary_ip = juju.status().apps[dovecot_charm].units[secondary].public_address + logging.info(f"Checking for email on secondary via IMAP at {secondary_ip}:993...") + assert _check_mail_via_imap(secondary_ip, user, password, subject), ( + f"Email with subject '{subject}' not found on secondary after force-sync" + ) - status = juju.status() - secondary_hostname = _get_unit_hostname(status, dovecot_charm, secondary) - logging.info(f"Secondary hostname: {secondary_hostname}") + # force-sync must fail on secondary + with pytest.raises(jubilant.TaskError) as exc_info: + juju.run(unit=secondary, action="force-sync", wait=2 * 60) + assert cast(jubilant.TaskError, exc_info.value).task.status == "failed" + logging.info("force-sync on secondary correctly failed.") - script_path = "/usr/local/bin/sync-to-secondary.sh" - cmd = f"cat {script_path}" - script_content = juju.exec(cmd, unit=primary).stdout - logging.info(f"Sync script content on Primary:\n{script_content}") - assert secondary_hostname in script_content, ( - "Secondary hostname not found in sync script on Primary" - ) +@pytest.mark.timeout(10 * 60) +def test_auto_sync(juju: jubilant.Juju, dovecot_charm: str): + """Auto-sync via cron replicates mail from primary to secondary within 2 minutes.""" + status = juju.status() + units = sorted(status.apps[dovecot_charm].units.keys(), key=lambda x: int(x.split("/")[-1])) + assert len(units) >= 2, "Need at least 2 units; run test_force_sync_action first" + primary, secondary = units[0], units[1] - logging.info("Running force-sync on Primary...") + logging.info(f"Ensuring primary-unit is set to {primary}...") + juju.config(dovecot_charm, {"primary-unit": primary}) + juju.wait(jubilant.all_active, timeout=5 * 60) - # Ensure a real system user exists for doveadm user lookup. - # A bare /srv/mail/ directory is not enough for dsync. - sync_user = "syncuser" + # Remove legacy HA test users that can break dsync on reruns. for unit in (primary, secondary): - juju.exec("rm -rf /srv/mail/syncuser /srv/mail/sync-* /srv/mail/testuser", unit=unit) + juju.exec("rm -rf /srv/mail/syncuser* /srv/mail/autosyncuser*", unit=unit) - juju.exec( - ( - f"id -u {sync_user} >/dev/null 2>&1 || " - f"useradd -M -d /srv/mail/{sync_user} -s /usr/sbin/nologin {sync_user}" - ), - unit=primary, - ) - juju.exec( - ( - f"mkdir -p /srv/mail/{sync_user}/Maildir/{{new,cur,tmp}} && " - f"chown -R {sync_user}:{sync_user} /srv/mail/{sync_user} && " - f"chmod 700 /srv/mail/{sync_user} /srv/mail/{sync_user}/Maildir" - ), - unit=primary, - ) - juju.exec( - ( - f"id -u {sync_user} >/dev/null 2>&1 || " - f"useradd -M -d /srv/mail/{sync_user} -s /usr/sbin/nologin {sync_user}" - ), - unit=secondary, - ) + # Set up a fresh test user + user = f"autosyncuser{token_hex(3)}" + password = token_hex(8) + for unit in (primary, secondary): + juju.exec(f"rm -rf /srv/mail/{user}", unit=unit) + _setup_mail_user(juju, [primary, secondary], user, password) - task = juju.run(unit=primary, action="force-sync", wait=100) - assert task.status == "completed" - assert task.results["result"] == "Sync completed successfully" + # Send email on primary + subject = f"Auto Sync Test {token_hex(4)}" + logging.info(f"Sending test email on primary with subject: {subject}") + juju.exec(f"echo 'test body' | mail -s '{subject}' {user}@localhost", unit=primary) - with pytest.raises(jubilant.TaskError) as exc_info: - juju.run(unit=secondary, action="force-sync", wait=100) - assert cast(jubilant.TaskError, exc_info.value).task.status == "failed" - logging.info("force-sync on Secondary correctly failed.") + previous_sync_mtime = _get_last_sync_mtime(juju, primary) + previous_cron_count = _get_sync_cron_run_count(juju, primary) - logging.info("HA Failover test passed.") + try: + # Lower sync schedule to every minute, wait for reconcile + logging.info("Setting sync-schedule to */1 * * * * (every minute)...") + juju.config(dovecot_charm, {"sync-schedule": "*/1 * * * *"}) + juju.wait(jubilant.all_active, timeout=5 * 60) + + logging.info("Waiting for first cron-triggered sync signal on primary...") + _wait_for_sync_trigger(juju, primary, previous_sync_mtime, previous_cron_count) + + # Verify email arrived on secondary via IMAP + secondary_ip = juju.status().apps[dovecot_charm].units[secondary].public_address + logging.info(f"Checking for email on secondary via IMAP at {secondary_ip}:993...") + synced = _check_mail_via_imap(secondary_ip, user, password, subject) + if not synced: + logging.info("Email not found after first cron sync; waiting for one more cron run...") + current_mtime = _get_last_sync_mtime(juju, primary) + current_cron_count = _get_sync_cron_run_count(juju, primary) + _wait_for_sync_trigger( + juju, + primary, + current_mtime, + current_cron_count, + timeout=2 * 60, + ) + synced = _check_mail_via_imap(secondary_ip, user, password, subject) + + assert synced, f"Email with subject '{subject}' not found on secondary after auto-sync" + finally: + # Reset sync-schedule to default + logging.info("Resetting sync-schedule to default...") + juju.config(dovecot_charm, {"sync-schedule": "*/30 * * * *"}) + juju.wait(jubilant.all_active, timeout=5 * 60) From 01aaf9442550a70f5092b5dd6a514e2be0788222 Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Wed, 22 Apr 2026 09:52:40 +0300 Subject: [PATCH 33/39] chore: fmt --- dovecot-charm/tests/integration/test_ha.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dovecot-charm/tests/integration/test_ha.py b/dovecot-charm/tests/integration/test_ha.py index b0fcd60..a7b9f5a 100644 --- a/dovecot-charm/tests/integration/test_ha.py +++ b/dovecot-charm/tests/integration/test_ha.py @@ -112,7 +112,9 @@ def _wait_for_sync_trigger( deadline = time.time() + timeout while time.time() < deadline: current_mtime = _get_last_sync_mtime(juju, unit) - if current_mtime is not None and (previous_mtime is None or current_mtime > previous_mtime): + if current_mtime is not None and ( + previous_mtime is None or current_mtime > previous_mtime + ): return current_mtime current_cron_count = _get_sync_cron_run_count(juju, unit) From 54893224808837b590e3d642f7dc09fdb3204874 Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Wed, 22 Apr 2026 15:52:11 +0300 Subject: [PATCH 34/39] refactor(ha): enhance dual unit support in HA tests and improve sync logging --- .../templates/sync-to-secondary_cron.tmpl | 1 + dovecot-charm/tests/integration/conftest.py | 60 +++++++ dovecot-charm/tests/integration/test_ha.py | 161 +++++++++++------- 3 files changed, 156 insertions(+), 66 deletions(-) diff --git a/dovecot-charm/templates/sync-to-secondary_cron.tmpl b/dovecot-charm/templates/sync-to-secondary_cron.tmpl index cd2e1a2..c91b0b5 100644 --- a/dovecot-charm/templates/sync-to-secondary_cron.tmpl +++ b/dovecot-charm/templates/sync-to-secondary_cron.tmpl @@ -1 +1,2 @@ {{ schedule }} root /usr/local/bin/sync-to-secondary.sh 2>&1 | /usr/bin/logger -t sync-to-secondary + diff --git a/dovecot-charm/tests/integration/conftest.py b/dovecot-charm/tests/integration/conftest.py index cbd7c40..d74dfdd 100644 --- a/dovecot-charm/tests/integration/conftest.py +++ b/dovecot-charm/tests/integration/conftest.py @@ -140,3 +140,63 @@ def tls_charm(juju: jubilant.Juju) -> str: logging.info(f"{tls_app} already deployed, skipping deployment.") return tls_app + + +@pytest.fixture(scope="module") +def dovecot_charm_dual_unit( + charm: str, + juju: jubilant.Juju, + tls_charm: str, +) -> str: + """Build and deploy the charm.""" + logging.info(f"Checking for existing application {APP_NAME}...") + luks_key = token_hex(16) + + if not juju.status().apps.get(APP_NAME): + logging.info(f"Application {APP_NAME} not found, proceeding with deployment.") + + secret_id = juju.cli("add-secret", "dovecot-luks-key", f"key={luks_key}").strip() + logging.info(f"Created LUKS secret: {secret_id}") + + config = { + "mailname": "example.com", + "postmaster-address": "postmaster@example.com", + "primary-unit": f"{APP_NAME}/0", + "luks-auto-provisioning": True, + "luks-key": secret_id, + } + charm_path = charm if charm.startswith(("./", "/")) else f"./{charm}" + juju.deploy( + charm_path, + app=APP_NAME, + config=config, + constraints={"virt-type": "virtual-machine"}, + trust=True, + num_units=2, + ) + else: + if len(juju.status().apps[APP_NAME].units) < 2: + logging.info("Adding the second unit...") + juju.add_unit(APP_NAME, num_units=1) + + def two_units_active(status): + app = status.apps.get(APP_NAME) + if not app or len(app.units) < 2: + return False + return jubilant.all_active(status) + + logging.info("Waiting for 2 units to be active...") + juju.wait(two_units_active, timeout=10 * 60) + + juju.cli("grant-secret", "dovecot-luks-key", APP_NAME) + try: + logging.info("Adding TLS relation...") + juju.integrate(f"{APP_NAME}:certificates", f"{tls_charm}:certificates") + except Exception: + logging.info("TLS relation already there...") + logging.info("Waiting for active status...") + juju.wait( + lambda status: jubilant.all_active(status, APP_NAME, tls_charm), + timeout=10 * 60, + ) + return APP_NAME diff --git a/dovecot-charm/tests/integration/test_ha.py b/dovecot-charm/tests/integration/test_ha.py index a7b9f5a..c578e05 100644 --- a/dovecot-charm/tests/integration/test_ha.py +++ b/dovecot-charm/tests/integration/test_ha.py @@ -13,17 +13,6 @@ import pytest -def _get_unit_hostname(status, app_name, unit_name): - """Helper to get unit hostname from status.""" - try: - machine = status.apps[app_name].units[unit_name].machine - return status.machines[machine].hostname - except KeyError: - message = f"Could not determine hostname for unit {unit_name} from Juju status." - logging.error(message) - pytest.fail(message) - - def _check_mail_via_imap(unit_ip: str, user: str, password: str, subject: str) -> bool: """Poll IMAP on unit_ip until the email with the given subject is found.""" context = ssl.create_default_context() @@ -54,9 +43,20 @@ def _check_mail_via_imap(unit_ip: str, user: str, password: str, subject: str) - return False -def _setup_mail_user(juju: jubilant.Juju, units: list[str], user: str, password: str): - """Create a system user with a Maildir on each unit, set password on all units.""" - for unit in units: +def _setup_mail_user( + juju: jubilant.Juju, + primary: str, + secondary: str, + user: str, + password: str, +): + """Create a mail user on both units. + + The system account and password are created on both units so PAM auth works + on the secondary after sync. The Maildir is only initialised on the primary + so that dsync can replicate it to the secondary without GUID conflicts. + """ + for unit in (primary, secondary): juju.exec( ( f"id -u {user} >/dev/null 2>&1 || " @@ -67,13 +67,14 @@ def _setup_mail_user(juju: jubilant.Juju, units: list[str], user: str, password: juju.exec(f"echo '{user}:{password}' | chpasswd", unit=unit) juju.exec(f"usermod -aG mail {user}", unit=unit) - # Maildir only needs to exist on primary so doveadm backup has something to sync - primary = units[0] + # Maildir only on primary — dsync creates it on the secondary during the + # first sync. Pre-initialising it on the secondary would give INBOX a + # different GUID and cause doveadm backup to fail with + # "mailbox_delete failed: INBOX can't be deleted". juju.exec( ( - f"mkdir -p /srv/mail/{user}/Maildir/{{new,cur,tmp}} && " - f"chown -R {user}:{user} /srv/mail/{user} && " - f"chmod 700 /srv/mail/{user} /srv/mail/{user}/Maildir" + f"install -d -m 0700 -o {user} -g mail /srv/mail/{user} && " + f"doveadm mailbox create -u {user} INBOX 2>/dev/null || true" ), unit=primary, ) @@ -88,12 +89,45 @@ def _get_last_sync_mtime(juju: jubilant.Juju, unit: str) -> int | None: def _get_sync_cron_run_count(juju: jubilant.Juju, unit: str) -> int: - """Return count of sync-to-secondary cron executions recorded in syslog.""" - output = juju.exec( - "grep -c 'sync-to-secondary.sh' /var/log/syslog 2>/dev/null || true", + """Return count of sync-to-secondary cron executions from syslog or direct log. + + Works across charm versions: newer versions log via logger to syslog, + older versions redirect directly to /var/log/sync-to-secondary.log. + """ + # Try syslog first (newer charm versions with logger) + syslog_output = juju.exec( + "grep -c 'sync-to-secondary' /var/log/syslog 2>/dev/null || true", + unit=unit, + ).stdout.strip() + syslog_count = int(syslog_output) if syslog_output.isdigit() else 0 + + # Also check direct sync log file (older charm versions) + synclog_output = juju.exec( + "wc -l /var/log/sync-to-secondary.log 2>/dev/null | awk '{print $1}' || true", unit=unit, ).stdout.strip() - return int(output) if output.isdigit() else 0 + synclog_count = int(synclog_output) if synclog_output.isdigit() else 0 + + # Return the higher count (more reliable detector across versions) + return max(syslog_count, synclog_count) + + +def _get_sync_log_content(juju: jubilant.Juju, unit: str, lines: int = 20) -> str: + """Return last N sync-to-secondary lines from syslog for debugging.""" + output = juju.exec( + f"grep 'sync-to-secondary' /var/log/syslog 2>/dev/null | tail -n {lines} || echo 'No sync entries in syslog'", + unit=unit, + ).stdout + return output + + +def _get_cron_file_content(juju: jubilant.Juju, unit: str) -> str: + """Return content of the sync-to-secondary cron file for debugging.""" + output = juju.exec( + "cat /etc/cron.d/sync-to-secondary 2>/dev/null || echo 'Cron file not found'", + unit=unit, + ).stdout + return output def _wait_for_sync_trigger( @@ -104,12 +138,15 @@ def _wait_for_sync_trigger( timeout: int = 4 * 60, poll_interval: int = 5, ) -> int: - """Wait for a cron-triggered sync signal and return observed marker mtime. + """Wait until /srv/mail/.last-dsync mtime advances, indicating a completed sync. - Accepts either /srv/mail/.last-dsync mtime advance or syslog cron count increase - to work across charm revisions with different script logging/exit behavior. + The sync script touches .last-dsync only at the very end, so this is a + reliable end-of-sync marker. Falls back to syslog cron count increase only + when .last-dsync never existed (e.g. no users yet), but in that case we also + add a grace sleep so any in-progress dsync can finish. """ deadline = time.time() + timeout + cron_fired = False while time.time() < deadline: current_mtime = _get_last_sync_mtime(juju, unit) if current_mtime is not None and ( @@ -118,8 +155,11 @@ def _wait_for_sync_trigger( return current_mtime current_cron_count = _get_sync_cron_run_count(juju, unit) - if current_cron_count > previous_cron_count: - return current_mtime or 0 + if current_cron_count > previous_cron_count and not cron_fired: + logging.info( + "Cron fired (syslog count increased); waiting for .last-dsync to update..." + ) + cron_fired = True time.sleep(poll_interval) @@ -130,28 +170,16 @@ def _wait_for_sync_trigger( @pytest.mark.timeout(30 * 60) -def test_force_sync_action(juju: jubilant.Juju, dovecot_charm: str): +def test_force_sync_action(juju: jubilant.Juju, dovecot_charm_dual_unit: str): """force-sync action replicates mail from primary to secondary via doveadm backup.""" status = juju.status() - if len(status.apps[dovecot_charm].units) < 2: - logging.info("Adding the second unit...") - juju.add_unit(dovecot_charm, num_units=1) - - def two_units_active(status): - app = status.apps.get(dovecot_charm) - if not app or len(app.units) < 2: - return False - return jubilant.all_active(status) - - logging.info("Waiting for 2 units to be active...") - juju.wait(two_units_active, timeout=10 * 60) - - status = juju.status() - units = sorted(status.apps[dovecot_charm].units.keys(), key=lambda x: int(x.split("/")[-1])) + units = sorted( + status.apps[dovecot_charm_dual_unit].units.keys(), key=lambda x: int(x.split("/")[-1]) + ) primary, secondary = units[0], units[1] logging.info(f"Primary: {primary}, Secondary: {secondary}") - juju.config(dovecot_charm, {"primary-unit": primary}) + juju.config(dovecot_charm_dual_unit, {"primary-unit": primary}) juju.wait(jubilant.all_active, timeout=5 * 60) # Remove legacy HA test users that can break dsync on reruns. @@ -163,7 +191,7 @@ def two_units_active(status): password = token_hex(8) for unit in (primary, secondary): juju.exec(f"rm -rf /srv/mail/{user}", unit=unit) - _setup_mail_user(juju, [primary, secondary], user, password) + _setup_mail_user(juju, primary, secondary, user, password) # Send email on primary subject = f"Force Sync Test {token_hex(4)}" @@ -177,7 +205,7 @@ def two_units_active(status): assert task.results["result"] == "Sync completed successfully" # Verify email arrived on secondary via IMAP - secondary_ip = juju.status().apps[dovecot_charm].units[secondary].public_address + secondary_ip = juju.status().apps[dovecot_charm_dual_unit].units[secondary].public_address logging.info(f"Checking for email on secondary via IMAP at {secondary_ip}:993...") assert _check_mail_via_imap(secondary_ip, user, password, subject), ( f"Email with subject '{subject}' not found on secondary after force-sync" @@ -190,16 +218,16 @@ def two_units_active(status): logging.info("force-sync on secondary correctly failed.") -@pytest.mark.timeout(10 * 60) -def test_auto_sync(juju: jubilant.Juju, dovecot_charm: str): +def test_auto_sync(juju: jubilant.Juju, dovecot_charm_dual_unit: str): """Auto-sync via cron replicates mail from primary to secondary within 2 minutes.""" status = juju.status() - units = sorted(status.apps[dovecot_charm].units.keys(), key=lambda x: int(x.split("/")[-1])) - assert len(units) >= 2, "Need at least 2 units; run test_force_sync_action first" + units = sorted( + status.apps[dovecot_charm_dual_unit].units.keys(), key=lambda x: int(x.split("/")[-1]) + ) primary, secondary = units[0], units[1] logging.info(f"Ensuring primary-unit is set to {primary}...") - juju.config(dovecot_charm, {"primary-unit": primary}) + juju.config(dovecot_charm_dual_unit, {"primary-unit": primary}) juju.wait(jubilant.all_active, timeout=5 * 60) # Remove legacy HA test users that can break dsync on reruns. @@ -211,7 +239,7 @@ def test_auto_sync(juju: jubilant.Juju, dovecot_charm: str): password = token_hex(8) for unit in (primary, secondary): juju.exec(f"rm -rf /srv/mail/{user}", unit=unit) - _setup_mail_user(juju, [primary, secondary], user, password) + _setup_mail_user(juju, primary, secondary, user, password) # Send email on primary subject = f"Auto Sync Test {token_hex(4)}" @@ -224,32 +252,33 @@ def test_auto_sync(juju: jubilant.Juju, dovecot_charm: str): try: # Lower sync schedule to every minute, wait for reconcile logging.info("Setting sync-schedule to */1 * * * * (every minute)...") - juju.config(dovecot_charm, {"sync-schedule": "*/1 * * * *"}) + juju.config(dovecot_charm_dual_unit, {"sync-schedule": "*/1 * * * *"}) juju.wait(jubilant.all_active, timeout=5 * 60) + logging.info(f"Cron file after config change:\n{_get_cron_file_content(juju, primary)}") logging.info("Waiting for first cron-triggered sync signal on primary...") _wait_for_sync_trigger(juju, primary, previous_sync_mtime, previous_cron_count) # Verify email arrived on secondary via IMAP - secondary_ip = juju.status().apps[dovecot_charm].units[secondary].public_address + secondary_ip = juju.status().apps[dovecot_charm_dual_unit].units[secondary].public_address logging.info(f"Checking for email on secondary via IMAP at {secondary_ip}:993...") synced = _check_mail_via_imap(secondary_ip, user, password, subject) if not synced: - logging.info("Email not found after first cron sync; waiting for one more cron run...") - current_mtime = _get_last_sync_mtime(juju, primary) - current_cron_count = _get_sync_cron_run_count(juju, primary) - _wait_for_sync_trigger( - juju, - primary, - current_mtime, - current_cron_count, - timeout=2 * 60, - ) + logging.info("Email not found after first cron sync.") + logging.info(f"Sync log on primary:\n{_get_sync_log_content(juju, primary)}") + logging.info("Cron file content:") + logging.info(_get_cron_file_content(juju, primary)) + logging.info("Trying manual sync as fallback to verify sync mechanism works...") + juju.exec("/usr/local/bin/sync-to-secondary.sh", unit=primary) + time.sleep(15) synced = _check_mail_via_imap(secondary_ip, user, password, subject) + if not synced: + logging.info("Manual sync also failed. Checking sync log after manual run:") + logging.info(f"Sync log:\n{_get_sync_log_content(juju, primary, lines=30)}") assert synced, f"Email with subject '{subject}' not found on secondary after auto-sync" finally: # Reset sync-schedule to default logging.info("Resetting sync-schedule to default...") - juju.config(dovecot_charm, {"sync-schedule": "*/30 * * * *"}) + juju.config(dovecot_charm_dual_unit, {"sync-schedule": "*/30 * * * *"}) juju.wait(jubilant.all_active, timeout=5 * 60) From f803c209822e8ca93ba4d0c17aa69af696ae5d6d Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Fri, 24 Apr 2026 11:31:11 +0300 Subject: [PATCH 35/39] feat(ha): add SSH drop-in configuration for PermitRootLogin and validate cron schedule --- dovecot-charm/src/constants.py | 2 + dovecot-charm/src/ha.py | 75 ++++++++++++++++------ dovecot-charm/tests/integration/test_ha.py | 6 +- dovecot-charm/tests/unit/test_ha.py | 38 +++++++++++ 4 files changed, 98 insertions(+), 23 deletions(-) create mode 100644 dovecot-charm/tests/unit/test_ha.py diff --git a/dovecot-charm/src/constants.py b/dovecot-charm/src/constants.py index 30b4c1b..99de573 100644 --- a/dovecot-charm/src/constants.py +++ b/dovecot-charm/src/constants.py @@ -55,5 +55,7 @@ SYNC_TO_SECONDARY_CRONJOB_TEMPLATE = "sync-to-secondary_cron.tmpl" SSHD_CONFIG = Path("/etc/ssh/sshd_config") +SSHD_DROPIN_DIR = Path("/etc/ssh/sshd_config.d") +SSHD_DROPIN_FILE = SSHD_DROPIN_DIR / "99-dovecot-ha.conf" SSH_DIR = Path("/root/.ssh") SSH_HOST_KEY_FILE = Path("/etc/ssh/ssh_host_ed25519_key.pub") diff --git a/dovecot-charm/src/ha.py b/dovecot-charm/src/ha.py index fb0abae..6795f8e 100644 --- a/dovecot-charm/src/ha.py +++ b/dovecot-charm/src/ha.py @@ -21,6 +21,8 @@ SSH_DIR, SSH_HOST_KEY_FILE, SSHD_CONFIG, + SSHD_DROPIN_DIR, + SSHD_DROPIN_FILE, SYNC_TO_SECONDARY_CRONJOB_TARGET, SYNC_TO_SECONDARY_CRONJOB_TEMPLATE, SYNC_TO_SECONDARY_TARGET, @@ -126,32 +128,64 @@ def sync_known_hosts(charm: DovecotCharm) -> None: def ensure_root_ssh_login() -> None: - """Set PermitRootLogin to prohibit-password in sshd_config and reload sshd. + """Set PermitRootLogin via an sshd drop-in, validate, and reload sshd. + + Writes /etc/ssh/sshd_config.d/99-dovecot-ha.conf so that the distro-owned + sshd_config is not modified. Validates the resulting config with + ``sshd -t`` before reloading; rolls back the drop-in on failure. Raises: - HASetupError: If sshd reload fails. + HASetupError: If sshd validation or reload fails. """ if not SSHD_CONFIG.exists(): return - content = SSHD_CONFIG.read_text() - new_content = "" - found = False - for line in content.splitlines(keepends=True): - stripped = line.lstrip("#").strip() - if stripped.startswith("PermitRootLogin"): - new_content += "PermitRootLogin prohibit-password\n" - found = True + drop_in_content = "PermitRootLogin prohibit-password\n" + + if SSHD_DROPIN_FILE.exists() and SSHD_DROPIN_FILE.read_text() == drop_in_content: + return + + previous_exists = SSHD_DROPIN_FILE.exists() + previous_content = SSHD_DROPIN_FILE.read_text() if previous_exists else None + + SSHD_DROPIN_DIR.mkdir(mode=0o755, parents=True, exist_ok=True) + SSHD_DROPIN_FILE.write_text(drop_in_content) + + try: + subprocess.run( + ["/usr/sbin/sshd", "-t", "-f", str(SSHD_CONFIG)], + check=True, + capture_output=True, + text=True, + ) + except subprocess.CalledProcessError as e: + if previous_exists and previous_content is not None: + SSHD_DROPIN_FILE.write_text(previous_content) else: - new_content += line - if not found: - new_content += "\nPermitRootLogin prohibit-password\n" - if new_content != content: - SSHD_CONFIG.write_text(new_content) - try: - systemd.service_reload("ssh", restart_on_failure=True) - except subprocess.CalledProcessError as e: - raise HASetupError(f"Failed to reload sshd after config change: {e}") from e + SSHD_DROPIN_FILE.unlink(missing_ok=True) + raise HASetupError(f"Failed to validate sshd configuration: {e.stderr}") from e + + try: + systemd.service_reload("ssh", restart_on_failure=True) + except subprocess.CalledProcessError as e: + raise HASetupError(f"Failed to reload sshd after config change: {e}") from e + + +def _validate_cron_schedule(schedule: str) -> str: + """Validate and return a sanitised 5-field cron schedule string. + + Raises: + HASetupError: If the schedule contains unsafe characters or does not + consist of exactly 5 whitespace-separated fields. + """ + if "\n" in schedule or "\r" in schedule: + raise HASetupError(f"Invalid sync-schedule: value must not contain newlines: {schedule!r}") + fields = schedule.split() + if len(fields) != 5: + raise HASetupError( + f"Invalid sync-schedule: expected 5 fields, got {len(fields)}: {schedule!r}" + ) + return schedule def install_mail_sync_script(charm: DovecotCharm) -> None: @@ -186,8 +220,9 @@ def setup_mail_sync_cronjob(charm: DovecotCharm) -> None: return charm.unit.status = MaintenanceStatus("Setting up mail pool synchronization cronjob") + schedule = _validate_cron_schedule(charm.config.get("sync-schedule", "*/30 * * * *")) template_context = { - "schedule": charm.config.get("sync-schedule", "*/30 * * * *"), + "schedule": schedule, } template = charm.jinja.get_template(SYNC_TO_SECONDARY_CRONJOB_TEMPLATE) contents = template.render(template_context) diff --git a/dovecot-charm/tests/integration/test_ha.py b/dovecot-charm/tests/integration/test_ha.py index c578e05..0fed803 100644 --- a/dovecot-charm/tests/integration/test_ha.py +++ b/dovecot-charm/tests/integration/test_ha.py @@ -141,9 +141,9 @@ def _wait_for_sync_trigger( """Wait until /srv/mail/.last-dsync mtime advances, indicating a completed sync. The sync script touches .last-dsync only at the very end, so this is a - reliable end-of-sync marker. Falls back to syslog cron count increase only - when .last-dsync never existed (e.g. no users yet), but in that case we also - add a grace sleep so any in-progress dsync can finish. + reliable end-of-sync marker. Syslog cron count is checked only to log + that the cron job appears to have fired while we continue waiting for + .last-dsync to be updated. """ deadline = time.time() + timeout cron_fired = False diff --git a/dovecot-charm/tests/unit/test_ha.py b/dovecot-charm/tests/unit/test_ha.py new file mode 100644 index 0000000..7164ebf --- /dev/null +++ b/dovecot-charm/tests/unit/test_ha.py @@ -0,0 +1,38 @@ +# Copyright 2026 Canonical Ltd. +# See LICENSE file for licensing details. + +import pytest + +from exceptions import HASetupError +from ha import _validate_cron_schedule + + +class TestValidateCronSchedule: + def test_valid_standard_schedule(self): + assert _validate_cron_schedule("*/30 * * * *") == "*/30 * * * *" + + def test_valid_every_minute(self): + assert _validate_cron_schedule("*/1 * * * *") == "*/1 * * * *" + + def test_valid_specific_fields(self): + assert _validate_cron_schedule("0 4 * * 1") == "0 4 * * 1" + + def test_rejects_newline_in_schedule(self): + with pytest.raises(HASetupError, match="must not contain newlines"): + _validate_cron_schedule("*/30 * * * *\nbadline root /bin/evil") + + def test_rejects_carriage_return(self): + with pytest.raises(HASetupError, match="must not contain newlines"): + _validate_cron_schedule("*/30 * * * *\r") + + def test_rejects_too_few_fields(self): + with pytest.raises(HASetupError, match="expected 5 fields, got 4"): + _validate_cron_schedule("*/30 * * *") + + def test_rejects_too_many_fields(self): + with pytest.raises(HASetupError, match="expected 5 fields, got 6"): + _validate_cron_schedule("*/30 * * * * extra") + + def test_rejects_empty_string(self): + with pytest.raises(HASetupError, match="expected 5 fields, got 0"): + _validate_cron_schedule("") From 33bc065fdaa95e85673333661f4fe95b0c63447a Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Fri, 24 Apr 2026 12:57:52 +0300 Subject: [PATCH 36/39] feat(ha): enhance cron schedule validation and add unit tests for disallowed characters --- dovecot-charm/src/ha.py | 12 +++++++++- .../templates/sync-to-secondary.sh.tmpl | 2 +- dovecot-charm/tests/unit/test_ha.py | 23 +++++++++++++++++++ 3 files changed, 35 insertions(+), 2 deletions(-) diff --git a/dovecot-charm/src/ha.py b/dovecot-charm/src/ha.py index 6795f8e..af66ff2 100644 --- a/dovecot-charm/src/ha.py +++ b/dovecot-charm/src/ha.py @@ -6,6 +6,7 @@ from __future__ import annotations import logging +import re import socket import subprocess # nosec import typing @@ -174,6 +175,9 @@ def ensure_root_ssh_login() -> None: def _validate_cron_schedule(schedule: str) -> str: """Validate and return a sanitised 5-field cron schedule string. + Each field may only contain digits and the characters ``* / , - ?``. + The returned value is whitespace-normalised (fields joined by single spaces). + Raises: HASetupError: If the schedule contains unsafe characters or does not consist of exactly 5 whitespace-separated fields. @@ -185,7 +189,13 @@ def _validate_cron_schedule(schedule: str) -> str: raise HASetupError( f"Invalid sync-schedule: expected 5 fields, got {len(fields)}: {schedule!r}" ) - return schedule + _allowed = re.compile(r"^[0-9\*/,\-?]+$") + for field in fields: + if not _allowed.match(field): + raise HASetupError( + f"Invalid sync-schedule: field {field!r} contains disallowed characters" + ) + return " ".join(fields) def install_mail_sync_script(charm: DovecotCharm) -> None: diff --git a/dovecot-charm/templates/sync-to-secondary.sh.tmpl b/dovecot-charm/templates/sync-to-secondary.sh.tmpl index 41e0914..6ee3594 100644 --- a/dovecot-charm/templates/sync-to-secondary.sh.tmpl +++ b/dovecot-charm/templates/sync-to-secondary.sh.tmpl @@ -17,4 +17,4 @@ if [ "$found" -eq 0 ]; then echo "No Maildir found under {{ mail_root }}; nothing to sync." >&2 exit 0 fi -touch {{ mail_root }}/.last-dsync +touch "{{ mail_root }}/.last-dsync" diff --git a/dovecot-charm/tests/unit/test_ha.py b/dovecot-charm/tests/unit/test_ha.py index 7164ebf..e8c75e4 100644 --- a/dovecot-charm/tests/unit/test_ha.py +++ b/dovecot-charm/tests/unit/test_ha.py @@ -36,3 +36,26 @@ def test_rejects_too_many_fields(self): def test_rejects_empty_string(self): with pytest.raises(HASetupError, match="expected 5 fields, got 0"): _validate_cron_schedule("") + + def test_rejects_command_substitution(self): + with pytest.raises(HASetupError, match="disallowed characters"): + _validate_cron_schedule("$(rm) * * * *") + + def test_rejects_backticks(self): + with pytest.raises(HASetupError, match="disallowed characters"): + _validate_cron_schedule("`id` * * * *") + + def test_rejects_semicolon(self): + with pytest.raises(HASetupError, match="disallowed characters"): + _validate_cron_schedule("*;id * * * *") + + def test_rejects_pipe(self): + with pytest.raises(HASetupError, match="disallowed characters"): + _validate_cron_schedule("*|cat * * * *") + + def test_rejects_alphabetic_field(self): + with pytest.raises(HASetupError, match="disallowed characters"): + _validate_cron_schedule("* * * * MON") + + def test_normalises_whitespace(self): + assert _validate_cron_schedule("*/30 * * * *") == "*/30 * * * *" From 599a59dc2669cb75aba584acf3840f56582241bb Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Fri, 24 Apr 2026 13:39:36 +0300 Subject: [PATCH 37/39] feat(ha): add sync_schedule field to DovecotConfig and validate cron schedule format --- dovecot-charm/src/charm.py | 2 +- dovecot-charm/src/dovecot_config.py | 21 +++++++++ dovecot-charm/src/ha.py | 33 ++----------- dovecot-charm/tests/unit/test_config.py | 63 +++++++++++++++++++++++++ dovecot-charm/tests/unit/test_ha.py | 61 ------------------------ 5 files changed, 88 insertions(+), 92 deletions(-) delete mode 100644 dovecot-charm/tests/unit/test_ha.py diff --git a/dovecot-charm/src/charm.py b/dovecot-charm/src/charm.py index d5133ca..0edca8c 100644 --- a/dovecot-charm/src/charm.py +++ b/dovecot-charm/src/charm.py @@ -179,7 +179,7 @@ def _reconcile(self, event): ha.sync_known_hosts(self) if self._is_primary: ha.install_mail_sync_script(self) - ha.setup_mail_sync_cronjob(self) + ha.setup_mail_sync_cronjob(self, dovecot_config) except HASetupError as e: self.unit.status = BlockedStatus(str(e)) return diff --git a/dovecot-charm/src/dovecot_config.py b/dovecot-charm/src/dovecot_config.py index 413baf8..f038d5e 100644 --- a/dovecot-charm/src/dovecot_config.py +++ b/dovecot-charm/src/dovecot_config.py @@ -4,6 +4,7 @@ """Dovecot charm configuration.""" import logging +import re from typing import TYPE_CHECKING from ops import ModelError, SecretNotFoundError @@ -62,6 +63,10 @@ class DovecotConfig(BaseModel): "", description="LUKS passphrase from the luks-key secret. Required when luks_auto_provisioning is true.", ) + sync_schedule: str = Field( + "*/30 * * * *", + description="Cron schedule for syncing mail from primary to secondary units.", + ) @field_validator("luks_key", mode="after") @classmethod @@ -72,6 +77,21 @@ def _validate_luks_key(cls, value: str, info: ValidationInfo) -> str: raise ValueError("luks-key secret must be set when luks-auto-provisioning is enabled") return value + @field_validator("sync_schedule", mode="after") + @classmethod + def _validate_sync_schedule(cls, value: str) -> str: + """Validate the cron schedule: 5 fields, safe characters only.""" + if "\n" in value or "\r" in value: + raise ValueError("sync-schedule must not contain newlines") + fields = value.split() + if len(fields) != 5: + raise ValueError(f"sync-schedule must have exactly 5 fields, got {len(fields)}") + allowed = re.compile(r"^[0-9\*/,\-?]+$") + for field in fields: + if not allowed.match(field): + raise ValueError(f"sync-schedule field {field!r} contains disallowed characters") + return " ".join(fields) + @field_validator("primary_unit", mode="after") @classmethod def _validate_primary_unit_exists(cls, value: str, info: ValidationInfo) -> str: @@ -115,6 +135,7 @@ def from_charm(cls, charm: "DovecotCharm") -> "DovecotConfig": "primary_unit": config.get("primary-unit"), "luks_auto_provisioning": luks_auto_provisioning, "luks_key": luks_key, + "sync_schedule": config.get("sync-schedule", "*/30 * * * *"), }, context={"charm": charm}, ) diff --git a/dovecot-charm/src/ha.py b/dovecot-charm/src/ha.py index af66ff2..80ead62 100644 --- a/dovecot-charm/src/ha.py +++ b/dovecot-charm/src/ha.py @@ -6,7 +6,6 @@ from __future__ import annotations import logging -import re import socket import subprocess # nosec import typing @@ -33,6 +32,7 @@ if typing.TYPE_CHECKING: from charm import DovecotCharm + from dovecot_config import DovecotConfig logger = logging.getLogger(__name__) @@ -172,32 +172,6 @@ def ensure_root_ssh_login() -> None: raise HASetupError(f"Failed to reload sshd after config change: {e}") from e -def _validate_cron_schedule(schedule: str) -> str: - """Validate and return a sanitised 5-field cron schedule string. - - Each field may only contain digits and the characters ``* / , - ?``. - The returned value is whitespace-normalised (fields joined by single spaces). - - Raises: - HASetupError: If the schedule contains unsafe characters or does not - consist of exactly 5 whitespace-separated fields. - """ - if "\n" in schedule or "\r" in schedule: - raise HASetupError(f"Invalid sync-schedule: value must not contain newlines: {schedule!r}") - fields = schedule.split() - if len(fields) != 5: - raise HASetupError( - f"Invalid sync-schedule: expected 5 fields, got {len(fields)}: {schedule!r}" - ) - _allowed = re.compile(r"^[0-9\*/,\-?]+$") - for field in fields: - if not _allowed.match(field): - raise HASetupError( - f"Invalid sync-schedule: field {field!r} contains disallowed characters" - ) - return " ".join(fields) - - def install_mail_sync_script(charm: DovecotCharm) -> None: """Render and install the mail pool synchronization script. @@ -218,7 +192,7 @@ def install_mail_sync_script(charm: DovecotCharm) -> None: host.write_file(SYNC_TO_SECONDARY_TARGET, contents, perms=0o755) -def setup_mail_sync_cronjob(charm: DovecotCharm) -> None: +def setup_mail_sync_cronjob(charm: DovecotCharm, dovecot_config: DovecotConfig) -> None: """Set up the mail pool synchronization cronjob. Skips writing and does not restart cron if the file content is unchanged. @@ -230,9 +204,8 @@ def setup_mail_sync_cronjob(charm: DovecotCharm) -> None: return charm.unit.status = MaintenanceStatus("Setting up mail pool synchronization cronjob") - schedule = _validate_cron_schedule(charm.config.get("sync-schedule", "*/30 * * * *")) template_context = { - "schedule": schedule, + "schedule": dovecot_config.sync_schedule, } template = charm.jinja.get_template(SYNC_TO_SECONDARY_CRONJOB_TEMPLATE) contents = template.render(template_context) diff --git a/dovecot-charm/tests/unit/test_config.py b/dovecot-charm/tests/unit/test_config.py index 24828bf..c1d3649 100644 --- a/dovecot-charm/tests/unit/test_config.py +++ b/dovecot-charm/tests/unit/test_config.py @@ -6,6 +6,7 @@ import pytest from ops.model import BlockedStatus +from pydantic import ValidationError from dovecot_config import DovecotConfig, DovecotConfigInvalidError @@ -53,3 +54,65 @@ def test_from_charm_primary_unit_does_not_exist_raises_value_error(base_state): with pytest.raises(DovecotConfigInvalidError, match="Primary unit does not exist"): DovecotConfig.from_charm(charm) + + +# Valid config kwargs shared by sync_schedule tests. +_VALID_BASE = { + "mailname": "example.com", + "postmaster_address": "admin@example.com", + "primary_unit": "dovecot-charm/0", +} + + +class TestSyncScheduleValidation: + def test_valid_default(self): + cfg = DovecotConfig(**_VALID_BASE) + assert cfg.sync_schedule == "*/30 * * * *" + + def test_valid_every_minute(self): + cfg = DovecotConfig(**_VALID_BASE, sync_schedule="*/1 * * * *") + assert cfg.sync_schedule == "*/1 * * * *" + + def test_valid_specific_fields(self): + cfg = DovecotConfig(**_VALID_BASE, sync_schedule="0 4 * * 1") + assert cfg.sync_schedule == "0 4 * * 1" + + def test_normalises_whitespace(self): + cfg = DovecotConfig(**_VALID_BASE, sync_schedule="*/30 * * * *") + assert cfg.sync_schedule == "*/30 * * * *" + + def test_rejects_newline(self): + with pytest.raises(ValidationError, match="must not contain newlines"): + DovecotConfig(**_VALID_BASE, sync_schedule="*/30 * * * *\nbad root /bin/evil") + + def test_rejects_too_few_fields(self): + with pytest.raises(ValidationError, match="exactly 5 fields"): + DovecotConfig(**_VALID_BASE, sync_schedule="*/30 * * *") + + def test_rejects_too_many_fields(self): + with pytest.raises(ValidationError, match="exactly 5 fields"): + DovecotConfig(**_VALID_BASE, sync_schedule="*/30 * * * * extra") + + def test_rejects_empty_string(self): + with pytest.raises(ValidationError, match="exactly 5 fields"): + DovecotConfig(**_VALID_BASE, sync_schedule="") + + def test_rejects_command_substitution(self): + with pytest.raises(ValidationError, match="disallowed characters"): + DovecotConfig(**_VALID_BASE, sync_schedule="$(rm) * * * *") + + def test_rejects_backticks(self): + with pytest.raises(ValidationError, match="disallowed characters"): + DovecotConfig(**_VALID_BASE, sync_schedule="`id` * * * *") + + def test_rejects_semicolon(self): + with pytest.raises(ValidationError, match="disallowed characters"): + DovecotConfig(**_VALID_BASE, sync_schedule="*;id * * * *") + + def test_rejects_pipe(self): + with pytest.raises(ValidationError, match="disallowed characters"): + DovecotConfig(**_VALID_BASE, sync_schedule="*|cat * * * *") + + def test_rejects_alphabetic_field(self): + with pytest.raises(ValidationError, match="disallowed characters"): + DovecotConfig(**_VALID_BASE, sync_schedule="* * * * MON") diff --git a/dovecot-charm/tests/unit/test_ha.py b/dovecot-charm/tests/unit/test_ha.py deleted file mode 100644 index e8c75e4..0000000 --- a/dovecot-charm/tests/unit/test_ha.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2026 Canonical Ltd. -# See LICENSE file for licensing details. - -import pytest - -from exceptions import HASetupError -from ha import _validate_cron_schedule - - -class TestValidateCronSchedule: - def test_valid_standard_schedule(self): - assert _validate_cron_schedule("*/30 * * * *") == "*/30 * * * *" - - def test_valid_every_minute(self): - assert _validate_cron_schedule("*/1 * * * *") == "*/1 * * * *" - - def test_valid_specific_fields(self): - assert _validate_cron_schedule("0 4 * * 1") == "0 4 * * 1" - - def test_rejects_newline_in_schedule(self): - with pytest.raises(HASetupError, match="must not contain newlines"): - _validate_cron_schedule("*/30 * * * *\nbadline root /bin/evil") - - def test_rejects_carriage_return(self): - with pytest.raises(HASetupError, match="must not contain newlines"): - _validate_cron_schedule("*/30 * * * *\r") - - def test_rejects_too_few_fields(self): - with pytest.raises(HASetupError, match="expected 5 fields, got 4"): - _validate_cron_schedule("*/30 * * *") - - def test_rejects_too_many_fields(self): - with pytest.raises(HASetupError, match="expected 5 fields, got 6"): - _validate_cron_schedule("*/30 * * * * extra") - - def test_rejects_empty_string(self): - with pytest.raises(HASetupError, match="expected 5 fields, got 0"): - _validate_cron_schedule("") - - def test_rejects_command_substitution(self): - with pytest.raises(HASetupError, match="disallowed characters"): - _validate_cron_schedule("$(rm) * * * *") - - def test_rejects_backticks(self): - with pytest.raises(HASetupError, match="disallowed characters"): - _validate_cron_schedule("`id` * * * *") - - def test_rejects_semicolon(self): - with pytest.raises(HASetupError, match="disallowed characters"): - _validate_cron_schedule("*;id * * * *") - - def test_rejects_pipe(self): - with pytest.raises(HASetupError, match="disallowed characters"): - _validate_cron_schedule("*|cat * * * *") - - def test_rejects_alphabetic_field(self): - with pytest.raises(HASetupError, match="disallowed characters"): - _validate_cron_schedule("* * * * MON") - - def test_normalises_whitespace(self): - assert _validate_cron_schedule("*/30 * * * *") == "*/30 * * * *" From 18de7ab51fdcc8e7478705e684e41c1ef82428fa Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Fri, 24 Apr 2026 14:45:53 +0300 Subject: [PATCH 38/39] feat(ha): ensure privsep directory exists for SSHD config checks --- dovecot-charm/src/ha.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/dovecot-charm/src/ha.py b/dovecot-charm/src/ha.py index 80ead62..9043560 100644 --- a/dovecot-charm/src/ha.py +++ b/dovecot-charm/src/ha.py @@ -152,6 +152,9 @@ def ensure_root_ssh_login() -> None: SSHD_DROPIN_DIR.mkdir(mode=0o755, parents=True, exist_ok=True) SSHD_DROPIN_FILE.write_text(drop_in_content) + # sshd -t requires the privsep directory to exist even for config checks. + Path("/run/sshd").mkdir(mode=0o755, exist_ok=True) + try: subprocess.run( ["/usr/sbin/sshd", "-t", "-f", str(SSHD_CONFIG)], From d7665c11e0ad6cf588dee285ded9e4fe6f4974dd Mon Sep 17 00:00:00 2001 From: Ali Ugur Date: Fri, 24 Apr 2026 15:43:36 +0300 Subject: [PATCH 39/39] feat(ha): enhance sync_schedule validation to disallow question mark and update related tests --- dovecot-charm/src/dovecot_config.py | 2 +- dovecot-charm/src/ha.py | 54 +++++++++++++++++++++---- dovecot-charm/tests/unit/test_config.py | 4 ++ 3 files changed, 51 insertions(+), 9 deletions(-) diff --git a/dovecot-charm/src/dovecot_config.py b/dovecot-charm/src/dovecot_config.py index f038d5e..5b870fa 100644 --- a/dovecot-charm/src/dovecot_config.py +++ b/dovecot-charm/src/dovecot_config.py @@ -86,7 +86,7 @@ def _validate_sync_schedule(cls, value: str) -> str: fields = value.split() if len(fields) != 5: raise ValueError(f"sync-schedule must have exactly 5 fields, got {len(fields)}") - allowed = re.compile(r"^[0-9\*/,\-?]+$") + allowed = re.compile(r"^[0-9\*/,\-]+$") for field in fields: if not allowed.match(field): raise ValueError(f"sync-schedule field {field!r} contains disallowed characters") diff --git a/dovecot-charm/src/ha.py b/dovecot-charm/src/ha.py index 9043560..8a3ff3e 100644 --- a/dovecot-charm/src/ha.py +++ b/dovecot-charm/src/ha.py @@ -71,26 +71,42 @@ def setup_ssh_keys(charm: DovecotCharm) -> None: relation.data[charm.unit]["public_key"] = pub_key relation.data[charm.unit]["hostname"] = socket.gethostname() + # Publish own IP so peers can restrict root SSH to known addresses. + binding = charm.model.get_binding(PEER_RELATION_NAME) + if binding: + relation.data[charm.unit]["ip_address"] = str(binding.network.bind_address) + if SSH_HOST_KEY_FILE.exists(): host_key = SSH_HOST_KEY_FILE.read_text().strip() relation.data[charm.unit]["ssh_host_key"] = host_key def sync_authorized_keys(charm: DovecotCharm) -> None: - """Collect public keys from all peer units and write authorized_keys.""" + """Collect public keys and IPs from all peer units and write authorized_keys. + + Also calls ensure_root_ssh_login with the collected peer IPs so that root + SSH key login is restricted to known peer addresses only. + """ relation = charm.model.get_relation(PEER_RELATION_NAME) if not relation: return authorized_keys = [] + peer_ips: list[str] = [] for unit in relation.units: pk = relation.data[unit].get("public_key") if pk: authorized_keys.append(pk) + ip = relation.data[unit].get("ip_address") + if ip: + peer_ips.append(ip) our_pk = relation.data[charm.unit].get("public_key") if our_pk: authorized_keys.append(our_pk) + our_ip = relation.data[charm.unit].get("ip_address") + if our_ip: + peer_ips.append(our_ip) if not authorized_keys: return @@ -99,7 +115,7 @@ def sync_authorized_keys(charm: DovecotCharm) -> None: auth_file.write_text("\n".join(authorized_keys) + "\n") auth_file.chmod(0o600) - ensure_root_ssh_login() + ensure_root_ssh_login(peer_ips) def sync_known_hosts(charm: DovecotCharm) -> None: @@ -128,12 +144,18 @@ def sync_known_hosts(charm: DovecotCharm) -> None: known_hosts_file.chmod(0o600) -def ensure_root_ssh_login() -> None: - """Set PermitRootLogin via an sshd drop-in, validate, and reload sshd. +def ensure_root_ssh_login(peer_ips: list[str]) -> None: + """Set PermitRootLogin via an sshd drop-in restricted to peer addresses. - Writes /etc/ssh/sshd_config.d/99-dovecot-ha.conf so that the distro-owned - sshd_config is not modified. Validates the resulting config with - ``sshd -t`` before reloading; rolls back the drop-in on failure. + Writes /etc/ssh/sshd_config.d/99-dovecot-ha.conf with: + - A global ``PermitRootLogin no`` baseline. + - A ``Match Address`` block that permits ``prohibit-password`` only for + the supplied peer IP addresses. + + If no peer IPs are known yet the drop-in is removed (or not written) so + that root login remains governed by the distro default until peers are + available. Validates with ``sshd -t`` before reloading; rolls back on + failure. Raises: HASetupError: If sshd validation or reload fails. @@ -141,7 +163,23 @@ def ensure_root_ssh_login() -> None: if not SSHD_CONFIG.exists(): return - drop_in_content = "PermitRootLogin prohibit-password\n" + if not peer_ips: + # No peers known yet — remove our drop-in if present and return. + if SSHD_DROPIN_FILE.exists(): + SSHD_DROPIN_FILE.unlink() + try: + systemd.service_reload("ssh", restart_on_failure=True) + except subprocess.CalledProcessError as e: + raise HASetupError(f"Failed to reload sshd after config change: {e}") from e + return + + address_list = ",".join(sorted(set(peer_ips))) + drop_in_content = ( + "PermitRootLogin no\n" + "\n" + f"Match Address {address_list}\n" + " PermitRootLogin prohibit-password\n" + ) if SSHD_DROPIN_FILE.exists() and SSHD_DROPIN_FILE.read_text() == drop_in_content: return diff --git a/dovecot-charm/tests/unit/test_config.py b/dovecot-charm/tests/unit/test_config.py index c1d3649..5581769 100644 --- a/dovecot-charm/tests/unit/test_config.py +++ b/dovecot-charm/tests/unit/test_config.py @@ -116,3 +116,7 @@ def test_rejects_pipe(self): def test_rejects_alphabetic_field(self): with pytest.raises(ValidationError, match="disallowed characters"): DovecotConfig(**_VALID_BASE, sync_schedule="* * * * MON") + + def test_rejects_question_mark(self): + with pytest.raises(ValidationError, match="disallowed characters"): + DovecotConfig(**_VALID_BASE, sync_schedule="? * * * *")