diff --git a/poetry.lock b/poetry.lock index 1492f2873f..830d0c5328 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.3.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.3.4 and should not be changed by hand. [[package]] name = "allure-pytest" @@ -502,6 +502,18 @@ files = [ [package.dependencies] opentelemetry-api = "*" +[[package]] +name = "charmlibs-systemd" +version = "1.0.0" +description = "The charmlibs.systemd package." +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "charmlibs_systemd-1.0.0-py3-none-any.whl", hash = "sha256:37d4022e28f70f7a2a54fbff7c5694d25dc62dbb8680feffabde8c324a432199"}, + {file = "charmlibs_systemd-1.0.0.tar.gz", hash = "sha256:947e93b076e105509b190020ec16de051e9015c1eb12904192fb39489e0e1caa"}, +] + [[package]] name = "charset-normalizer" version = "3.4.7" @@ -3080,4 +3092,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4.0" -content-hash = "eabc76b77208cbb5dd78af27e1fb998f436203cf4688e0b82ec3e61b80ef0d2e" +content-hash = "6726f652f106215916c0ffb17d1e1146e60b688a1532ef75bbc63b57a9cd6a02" diff --git a/pyproject.toml b/pyproject.toml index 349db99637..ac39817970 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ psutil = "^7.2.2" charm-refresh = "^3.1.0.2" httpx = "^0.28.1" charmlibs-snap = "^1.0.1" +charmlibs-systemd = "^1.0.0" charmlibs-interfaces-tls-certificates = "^1.8.1" postgresql-charms-single-kernel = "16.1.11" diff --git a/refresh_versions.toml b/refresh_versions.toml index c5b2393b53..514c882744 100644 --- a/refresh_versions.toml +++ b/refresh_versions.toml @@ -6,6 +6,6 @@ name = "charmed-postgresql" [snap.revisions] # amd64 -x86_64 = "283" +x86_64 = "289" # arm64 -aarch64 = "282" +aarch64 = "288" diff --git a/src/charm.py b/src/charm.py index fd40d6ed2c..ff46a63386 100755 --- a/src/charm.py +++ b/src/charm.py @@ -411,13 +411,15 @@ def _post_snap_refresh(self, refresh: charm_refresh.Machines): Called after snap refresh """ try: - if raw_cert := self.get_secret(UNIT_SCOPE, "internal-cert"): - cert = load_pem_x509_certificate(raw_cert.encode()) - if ( + if ( + (raw_cert := self.get_secret(UNIT_SCOPE, "internal-cert")) + and (cert := load_pem_x509_certificate(raw_cert.encode())) + and ( cert.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value != self._unit_ip - ): - self.tls.generate_internal_peer_cert() + ) + ): + self.tls.generate_internal_peer_cert() except Exception: logger.exception("Unable to check or update internal cert") diff --git a/src/cluster.py b/src/cluster.py index f03f238294..24ffaaea1f 100644 --- a/src/cluster.py +++ b/src/cluster.py @@ -12,18 +12,15 @@ import re import shutil import subprocess -from asyncio import as_completed, create_task, run, wait -from contextlib import suppress from functools import cached_property from pathlib import Path -from ssl import CERT_NONE, create_default_context from typing import TYPE_CHECKING, Any, Literal, TypedDict import psutil import requests import tomli from charmlibs import snap -from httpx import AsyncClient, BasicAuth, HTTPError +from httpx import BasicAuth from jinja2 import Template from ops import BlockedStatus from pysyncobj.utility import TcpUtility, UtilityException @@ -58,7 +55,7 @@ POSTGRESQL_LOGS_PATH, TLS_CA_BUNDLE_FILE, ) -from utils import _change_owner, label2name, render_file +from utils import _change_owner, label2name, parallel_patroni_get_request, render_file logger = 
logging.getLogger(__name__) @@ -249,9 +246,28 @@ def cached_cluster_status(self): def cluster_status(self, alternative_endpoints: list | None = None) -> list[ClusterMember]: """Query the cluster status.""" + if not self._patroni_async_auth: + raise RetryError( + last_attempt=Future.construct(1, Exception("Unable to reach any units"), True) + ) + + # TODO we don't know the other cluster's ca + verify = not bool(alternative_endpoints) + if alternative_endpoints: + endpoints = alternative_endpoints + else: + endpoints = [] + if self.unit_ip: + endpoints.append(self.unit_ip) + for peer_ip in self.peers_ips: + endpoints.append(peer_ip) # Request info from cluster endpoint (which returns all members of the cluster). - if response := self.parallel_patroni_get_request( - f"/{PATRONI_CLUSTER_STATUS_ENDPOINT}", alternative_endpoints + if response := parallel_patroni_get_request( + f"/{PATRONI_CLUSTER_STATUS_ENDPOINT}", + endpoints, + f"{PATRONI_CONF_PATH}/{TLS_CA_BUNDLE_FILE}", + self._patroni_async_auth, + verify, ): logger.debug("API cluster_status: %s", response["members"]) return response["members"] @@ -295,54 +311,6 @@ def get_member_status(self, member_name: str) -> str: return member["state"] return "" - async def _httpx_get_request(self, url: str, verify: bool = True) -> dict[str, Any] | None: - if not self._patroni_async_auth: - return None - ssl_ctx = create_default_context() - if verify: - with suppress(FileNotFoundError): - ssl_ctx.load_verify_locations(cafile=f"{PATRONI_CONF_PATH}/{TLS_CA_BUNDLE_FILE}") - else: - ssl_ctx.check_hostname = False - ssl_ctx.verify_mode = CERT_NONE - async with AsyncClient( - auth=self._patroni_async_auth, timeout=API_REQUEST_TIMEOUT, verify=ssl_ctx - ) as client: - try: - return (await client.get(url)).raise_for_status().json() - except (HTTPError, ValueError): - return None - - async def _async_get_request( - self, uri: str, endpoints: list[str], verify: bool = True - ) -> dict[str, Any] | None: - tasks = [ - create_task(self._httpx_get_request(f"https://{ip}:8008{uri}", verify)) - for ip in endpoints - ] - for task in as_completed(tasks): - if result := await task: - for task in tasks: - task.cancel() - await wait(tasks) - return result - - def parallel_patroni_get_request( - self, uri: str, endpoints: list[str] | None = None - ) -> dict[str, Any] | None: - """Call all possible patroni endpoints in parallel.""" - if not endpoints: - endpoints = [] - if self.unit_ip: - endpoints.append(self.unit_ip) - for peer_ip in self.peers_ips: - endpoints.append(peer_ip) - verify = True - else: - # TODO we don't know the other cluster's ca - verify = False - return run(self._async_get_request(uri, endpoints, verify)) - def get_primary( self, unit_name_pattern=False, alternative_endpoints: list[str] | None = None ) -> str | None: diff --git a/src/constants.py b/src/constants.py index e3f6c84fa6..2d0bd51d9e 100644 --- a/src/constants.py +++ b/src/constants.py @@ -82,6 +82,17 @@ TRACING_PROTOCOL = "otlp_http" +# Watcher constants +WATCHER_OFFER_RELATION = "watcher-offer" +WATCHER_RELATION = "watcher" +WATCHER_USER = "watcher" + +# Labels are not confidential +WATCHER_PASSWORD_KEY = "watcher-password" # noqa: S105 +WATCHER_SECRET_LABEL = "watcher-secret" # noqa: S105 + +RAFT_PORT = 2222 + BACKUP_TYPE_OVERRIDES = {"full": "full", "differential": "diff", "incremental": "incr"} PLUGIN_OVERRIDES = {"audit": "pgaudit", "uuid_ossp": '"uuid-ossp"'} diff --git a/src/raft_controller.py b/src/raft_controller.py new file mode 100644 index 0000000000..ad009d97ae --- /dev/null +++ 
b/src/raft_controller.py
@@ -0,0 +1,414 @@
+# Copyright 2026 Canonical Ltd.
+# See LICENSE file for licensing details.
+
+"""Raft controller management for PostgreSQL watcher.
+
+This module manages a Patroni raft_controller node that participates in
+consensus without running PostgreSQL, providing the necessary third vote
+for quorum in 2-node PostgreSQL clusters.
+
+Uses Patroni's own ``patroni_raft_controller`` from the charmed-postgresql
+snap, which is the same battle-tested Raft implementation used by the
+PostgreSQL nodes. This guarantees wire compatibility with Patroni's
+KVStoreTTL class.
+
+The Raft service runs as a systemd service to ensure it persists between
+charm hook invocations.
+"""
+
+import logging
+from contextlib import suppress
+from ipaddress import IPv4Address
+from shutil import rmtree
+from typing import TYPE_CHECKING, TypedDict
+
+import psycopg2
+from charmlibs.systemd import (
+    SystemdError,
+    daemon_reload,
+    service_disable,
+    service_enable,
+    service_restart,
+    service_running,
+    service_start,
+    service_stop,
+)
+from jinja2 import Template
+from pysyncobj.utility import TcpUtility
+from tenacity import RetryError, Retrying, stop_after_attempt, wait_fixed
+
+from cluster import ClusterMember
+from constants import PATRONI_CLUSTER_STATUS_ENDPOINT
+from utils import create_directory, parallel_patroni_get_request, render_file
+
+if TYPE_CHECKING:
+    from charm import PostgresqlOperatorCharm
+
+logger = logging.getLogger(__name__)
+
+# Base directory for all Raft instances.
+# Must be under the snap's common path so that
+# charmed-postgresql.patroni-raft-controller can access it.
+RAFT_BASE_DIR = "/var/snap/charmed-postgresql/common/watcher-raft"
+SERVICE_FILE = "/etc/systemd/system/watcher-raft@.service"
+
+# Default health check configuration
+DEFAULT_RETRY_COUNT = 3
+DEFAULT_RETRY_INTERVAL_SECONDS = 7
+DEFAULT_QUERY_TIMEOUT_SECONDS = 5
+DEFAULT_CHECK_INTERVAL_SECONDS = 10
+
+# TCP keepalive settings to detect dead connections quickly
+TCP_KEEPALIVE_IDLE = 1  # Start keepalive probes after 1 second of idle
+TCP_KEEPALIVE_INTERVAL = 1  # Send keepalive probes every 1 second
+TCP_KEEPALIVE_COUNT = 3  # Consider connection dead after 3 failed probes
+
+
+class ClusterStatus(TypedDict):
+    """Type definition for the cluster status mapping."""
+
+    running: bool
+    connected: bool
+    has_quorum: bool
+    leader: str | None
+    members: list[str]
+
+
+def install_service() -> bool:
+    """Install the systemd template service for the Raft controller.
+
+    Returns:
+        True if the unit file was rendered and systemd reloaded, False on failure.
+    """
+    with open("templates/watcher.service.j2") as file:
+        template = Template(file.read())
+
+    rendered = template.render(config_file=RAFT_BASE_DIR)
+    render_file(SERVICE_FILE, rendered, 0o644, change_owner=False)
+
+    # Reload systemd to pick up the new service
+    try:
+        daemon_reload()
+        logger.info(f"Installed systemd service {SERVICE_FILE}")
+    except SystemdError as e:
+        logger.error(f"Failed to reload systemd: {e}")
+        return False
+
+    return True
+
+
+class RaftController:
+    """Manages the Raft service for consensus participation.
+
+    The Raft service runs as a systemd service to ensure it persists
+    between charm hook invocations. This is necessary because:
+    1. Each hook invocation creates a new Python process
+    2. pysyncobj requires a persistent process for Raft consensus
+    3. The systemd service ensures the Raft node stays running
+    """
+
+    def __init__(self, charm: "PostgresqlOperatorCharm", instance_id: str = "default"):
+        """Initialize the Raft controller.
+
+        Args:
+            charm: The PostgreSQL watcher charm instance.
+            instance_id: Unique identifier for this Raft instance. Used to
+                derive data directories, config files, and service names.
+                Defaults to "default" for backward compatibility.
+
+        """
+        self.charm = charm
+        self.instance_id = instance_id
+
+        # Derive all paths from instance_id
+        self.data_dir = f"{RAFT_BASE_DIR}/{instance_id}"
+        self.config_file = f"{RAFT_BASE_DIR}/{instance_id}/patroni-raft.yaml"
+        self.ca_file = f"{RAFT_BASE_DIR}/{instance_id}/patroni-ca.pem"
+        self.service_name = f"watcher-raft@{instance_id}"
+
+    def configure(
+        self,
+        self_port: int,
+        self_addr: str | None = None,
+        partner_addrs: list[str] | None = None,
+        password: str | None = None,
+        cas: str | None = None,
+    ) -> bool:
+        """Configure the Raft controller.
+
+        Args:
+            self_port: This node's Raft port.
+            self_addr: This node's Raft address.
+            partner_addrs: List of partner Raft addresses.
+            password: Raft cluster password.
+            cas: Patroni CA bundle.
+
+        Returns:
+            True if the controller was configured successfully, False otherwise.
+        """
+        if not partner_addrs:
+            partner_addrs = []
+
+        # Ensure data directory exists
+        create_directory(self.data_dir, 0o700)
+        create_directory(f"{self.data_dir}/raft", 0o700)
+
+        if not self_addr or not password:
+            logger.warning("Cannot configure Raft controller: missing self_addr or password")
+            return False
+
+        # Validate addresses before rendering them into the Raft config file
+        try:
+            IPv4Address(self_addr)
+        except Exception:
+            logger.error(f"Invalid self_addr format: {self_addr}")
+            return False
+        for addr in partner_addrs:
+            try:
+                IPv4Address(addr)
+            except Exception:
+                logger.error(f"Invalid partner address format: {addr}")
+                return False
+
+        with open("templates/watcher.yml.j2") as file:
+            template = Template(file.read())
+
+        # Write Patroni-compatible YAML config (includes password)
+        rendered = template.render(
+            self_addr=self_addr,
+            self_port=self_port,
+            partner_addrs=partner_addrs,
+            password=password,
+            data_dir=self.data_dir,
+        )
+        render_file(self.config_file, rendered, 0o600)
+        if cas:
+            render_file(self.ca_file, cas, 0o600)
+
+        logger.info(f"Raft controller configured: self={self_addr}, partners={partner_addrs}")
+        return True
+
+    def start(self) -> bool:
+        """Start the Raft controller service.
+
+        Returns:
+            True if started successfully, False otherwise.
+        """
+        if service_running(self.service_name):
+            logger.debug("Raft controller already running")
+            return True
+
+        try:
+            # Enable and start the service
+            service_enable(self.service_name)
+            service_start(self.service_name)
+            logger.info(f"Started Raft controller service {self.service_name}")
+            return True
+        except SystemdError as e:
+            logger.error(f"Failed to start Raft controller: {e}")
+            return False
+
+    def stop(self) -> bool:
+        """Stop the Raft controller service.
+
+        Returns:
+            True if stopped successfully, False otherwise.
+        """
+        if not service_running(self.service_name):
+            logger.debug("Raft controller not running")
+            return True
+
+        try:
+            service_stop(self.service_name)
+            logger.info(f"Stopped Raft controller service {self.service_name}")
+            return True
+        except SystemdError as e:
+            logger.error(f"Failed to stop Raft controller: {e}")
+            return False
+
+    def remove_service(self) -> bool:
+        """Stop and disable the Raft service and remove its data directory."""
+        if not self.stop():
+            return False
+
+        try:
+            service_disable(self.service_name)
+        except SystemdError as e:
+            logger.error(f"Failed to disable Raft controller service: {e}")
+            return False
+
+        try:
+            rmtree(self.data_dir)
+        except Exception as e:
+            logger.error(f"Failed to remove Raft controller directory: {e}")
+            return False
+
+        return True
+
+    def restart(self) -> bool:
+        """Restart the Raft controller service.
+
+        Returns:
+            True if restarted successfully, False otherwise.
+        """
+        try:
+            service_restart(self.service_name)
+            logger.info(f"Restarted Raft controller service {self.service_name}")
+            return True
+        except SystemdError as e:
+            logger.error(f"Failed to restart Raft controller: {e}")
+            return False
+
+    def get_status(self, self_port: int, password: str | None) -> ClusterStatus:
+        """Get the Raft controller status.
+
+        Returns:
+            Dictionary with status information.
+        """
+        is_running = service_running(self.service_name)
+        status: ClusterStatus = {
+            "running": is_running,
+            "connected": False,
+            "has_quorum": False,
+            "leader": None,
+            "members": [],
+        }
+
+        if not password or not is_running:
+            return status
+
+        # Query Raft status using pysyncobj TcpUtility
+        try:
+            utility = TcpUtility(password=password, timeout=3)
+            raft_status = utility.executeCommand(f"localhost:{self_port}", ["status"])
+            status["connected"] = True
+            status["has_quorum"] = raft_status.get("has_quorum", False)
+            status["leader"] = (
+                str(raft_status.get("leader")) if raft_status.get("leader") else None
+            )
+
+            # Extract member addresses from partner_node_status_server_* keys
+            prefix = "partner_node_status_server_"
+            members: list[str] = [str(raft_status["self"])]
+            for key in raft_status:
+                if key.startswith(prefix):
+                    members.append(key[len(prefix) :])
+            status["members"] = sorted(members)
+            return status
+        except Exception as e:
+            logger.debug(f"Error querying Raft status via TcpUtility: {e}")
+
+        return status
+
+    def check_all_endpoints(self, endpoints: list[str], password: str) -> dict[str, bool]:
+        """Test connectivity to all PostgreSQL endpoints.
+
+        WARNING: This method uses blocking time.sleep() for retry intervals
+        (up to ~38s worst case with 2 endpoints). Only call from Juju actions,
+        never from hook handlers.
+
+        Args:
+            endpoints: List of PostgreSQL unit IP addresses.
+            password: Password for the watcher user.
+
+        Returns:
+            Dictionary mapping each endpoint IP to True (reachable) or False.
+        """
+        results: dict[str, bool] = {}
+        for endpoint in endpoints:
+            results[endpoint] = self._check_endpoint_with_retries(endpoint, password)
+
+        self._last_health_results = results
+        return results
+
+    def _check_endpoint_with_retries(self, endpoint: str, password: str) -> bool:
+        """Check a single endpoint with retry logic.
+
+        Per acceptance criteria: Repeat tests at least 3 times before
+        deciding that an instance is no longer reachable, waiting 7 seconds
+        between every try.
+
+        Args:
+            endpoint: PostgreSQL endpoint IP address.
+            password: Password for the watcher user.
+
+        Returns:
+            True if the endpoint responded to a health query, False otherwise.
+        """
+        with suppress(RetryError):
+            for attempt in Retrying(
+                stop=stop_after_attempt(DEFAULT_RETRY_COUNT),
+                wait=wait_fixed(DEFAULT_RETRY_INTERVAL_SECONDS),
+            ):
+                with attempt:
+                    if result := self._execute_health_query(endpoint, password):
+                        logger.debug(f"Health check passed for {endpoint}")
+                        return result
+                    raise Exception(f"Cannot reach {endpoint}")
+
+        logger.error(f"Endpoint {endpoint} unhealthy after {DEFAULT_RETRY_COUNT} attempts")
+        return False
+
+    def _execute_health_query(self, endpoint: str, password: str) -> bool:
+        """Execute health check queries with TCP keepalive and timeout.
+
+        Per acceptance criteria:
+        - Testing actual queries (SELECT 1)
+        - Using direct and reserved connections (no pgbouncer)
+        - Setting TCP keepalive to avoid hanging on dead connections
+        - Setting query timeout
+
+        Args:
+            endpoint: PostgreSQL endpoint IP address.
+            password: Password for the watcher user.
+
+        Returns:
+            True if the health query succeeded, False otherwise.
+        """
+        connection = None
+        result = False
+        try:
+            # Connect directly to PostgreSQL port 5432 (not pgbouncer 6432)
+            # Using the 'postgres' database which always exists
+            with (
+                psycopg2.connect(
+                    host=endpoint,
+                    port=5432,
+                    dbname="postgres",
+                    user="watcher",
+                    password=password,
+                    connect_timeout=DEFAULT_QUERY_TIMEOUT_SECONDS,
+                    # TCP keepalive settings per acceptance criteria
+                    keepalives=1,
+                    keepalives_idle=TCP_KEEPALIVE_IDLE,
+                    keepalives_interval=TCP_KEEPALIVE_INTERVAL,
+                    keepalives_count=TCP_KEEPALIVE_COUNT,
+                    # Set options for query timeout
+                    options=f"-c statement_timeout={DEFAULT_QUERY_TIMEOUT_SECONDS * 1000}",
+                ) as connection,
+                connection.cursor() as cursor,
+            ):
+                # Minimal liveness query; a successful round trip marks the endpoint healthy
+                cursor.execute("SELECT 1")
+                result = True
+
+        except psycopg2.Error as e:
+            # Connection or query errors
+            logger.debug(f"Database error on {endpoint}: {e}")
+        finally:
+            if connection is not None:
+                try:
+                    connection.close()
+                except psycopg2.Error as e:
+                    logger.debug(f"Failed to close connection to {endpoint}: {e}")
+        return result
+
+    def cluster_status(self, endpoints: list[str]) -> list[ClusterMember]:
+        """Query the cluster status."""
+        # Request info from cluster endpoint (which returns all members of the cluster).
+ if response := parallel_patroni_get_request( + f"/{PATRONI_CLUSTER_STATUS_ENDPOINT}", endpoints, self.ca_file, None + ): + logger.debug("API cluster_status: %s", response["members"]) + return response["members"] + return [] diff --git a/src/relations/tls.py b/src/relations/tls.py index 4a0b9f9475..a7a313a1f3 100644 --- a/src/relations/tls.py +++ b/src/relations/tls.py @@ -217,7 +217,7 @@ def get_peer_ca_bundle(self) -> str: operator_ca = str(certs[0].ca) if certs else "" old_operator_ca = self.charm.get_secret(UNIT_SCOPE, "old-ca") or "" internal_ca = self.charm.get_secret(APP_SCOPE, "internal-ca") or "" - return "\n".join((operator_ca, old_operator_ca, internal_ca)) + return "\n".join((operator_ca, old_operator_ca, internal_ca)).strip() def generate_internal_peer_ca(self) -> None: """Generate internal peer CA using the tls lib.""" diff --git a/src/utils.py b/src/utils.py index 369dc173c9..d97700fda9 100644 --- a/src/utils.py +++ b/src/utils.py @@ -7,6 +7,14 @@ import pwd import secrets import string +from asyncio import as_completed, create_task, run, wait +from contextlib import suppress +from ssl import CERT_NONE, create_default_context +from typing import Any + +from httpx import AsyncClient, BasicAuth, HTTPError + +from constants import API_REQUEST_TIMEOUT def new_password() -> str: @@ -78,3 +86,46 @@ def _change_owner(path: str) -> None: user_database = pwd.getpwnam("_daemon_") # Set the correct ownership for the file or directory. os.chown(path, uid=user_database.pw_uid, gid=user_database.pw_gid) + + +async def _httpx_get_request( + url: str, cafile: str, auth: BasicAuth | None = None, verify: bool = True +) -> dict[str, Any] | None: + ssl_ctx = create_default_context() + if verify: + with suppress(FileNotFoundError): + ssl_ctx.load_verify_locations(cafile=cafile) + else: + ssl_ctx.check_hostname = False + ssl_ctx.verify_mode = CERT_NONE + async with AsyncClient(auth=auth, timeout=API_REQUEST_TIMEOUT, verify=ssl_ctx) as client: + try: + return (await client.get(url)).raise_for_status().json() + except (HTTPError, ValueError): + return None + + +async def _async_get_request( + uri: str, endpoints: list[str], cafile: str, auth: BasicAuth | None, verify: bool = True +) -> dict[str, Any] | None: + tasks = [ + create_task(_httpx_get_request(f"https://{ip}:8008{uri}", cafile, auth, verify)) + for ip in endpoints + ] + for task in as_completed(tasks): + if result := await task: + for task in tasks: + task.cancel() + await wait(tasks) + return result + + +def parallel_patroni_get_request( + uri: str, + endpoints: list[str], + cafile: str, + auth: BasicAuth | None = None, + verify: bool = True, +) -> dict[str, Any] | None: + """Call all possible patroni endpoints in parallel.""" + return run(_async_get_request(uri, endpoints, cafile, auth, verify)) diff --git a/templates/watcher.service.j2 b/templates/watcher.service.j2 new file mode 100644 index 0000000000..2df03728cf --- /dev/null +++ b/templates/watcher.service.j2 @@ -0,0 +1,19 @@ +[Unit] +Description=PostgreSQL Watcher Raft Service (%i) +After=network.target +Wants=network.target + +[Service] +Type=simple +# charmed-postgresql.patroni-raft-controller app lacks network interfaces +# in the snap profile, so run the controller under the patroni app profile. 
+ExecStart=/snap/bin/charmed-postgresql.patroni-raft-controller {{ config_file }}/%i/patroni-raft.yaml +Restart=always +RestartSec=5 +TimeoutStartSec=30 +TimeoutStopSec=30 +StandardOutput=journal +StandardError=journal + +[Install] +WantedBy=multi-user.target diff --git a/templates/watcher.yml.j2 b/templates/watcher.yml.j2 new file mode 100644 index 0000000000..a1708b2ba5 --- /dev/null +++ b/templates/watcher.yml.j2 @@ -0,0 +1,18 @@ +######################################################################################### +# [ WARNING ] +# watcher configuration file maintained by the postgres-operator +# local changes may be overwritten. +######################################################################################### +# For a complete reference of all the options for this configuration file, +# please refer to https://patroni.readthedocs.io/en/latest/SETTINGS.html. + +raft: + {% if partner_addrs -%} + partner_addrs: + {% endif -%} + {% for partner_addr in partner_addrs -%} + - {{ partner_addr }}:2222 + {% endfor %} + self_addr: '{{ self_addr }}:{{ self_port }}' + password: {{ password }} + data_dir: {{ data_dir }}/raft diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index 970213713e..44ba084127 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -94,7 +94,7 @@ def patroni(harness, peers_ips): def test_get_member_ip(peers_ips, patroni): with patch( - "charm.Patroni.parallel_patroni_get_request", return_value=None + "cluster.parallel_patroni_get_request", return_value=None ) as _parallel_patroni_get_request: # No IP if no members assert patroni.get_member_ip(patroni.member_name) is None @@ -163,7 +163,7 @@ def test_dict_to_hba_string(harness, patroni): def test_get_primary(peers_ips, patroni): with ( patch( - "charm.Patroni.parallel_patroni_get_request", return_value=None + "cluster.parallel_patroni_get_request", return_value=None ) as _parallel_patroni_get_request, ): # No primary if no members diff --git a/tests/unit/test_raft_controller.py b/tests/unit/test_raft_controller.py new file mode 100644 index 0000000000..f167c6233d --- /dev/null +++ b/tests/unit/test_raft_controller.py @@ -0,0 +1,98 @@ +# Copyright 2026 Canonical Ltd. +# See LICENSE file for licensing details. 
+ +from pathlib import Path +from unittest.mock import MagicMock, patch + +from charmlibs.systemd import SystemdError +from jinja2 import Template +from pytest import fixture + +from raft_controller import SERVICE_FILE, RaftController, install_service + + +@fixture +def controller(tmp_path: Path) -> RaftController: + controller = RaftController(MagicMock(), instance_id="rel42") + controller.data_dir = str(tmp_path / "watcher-raft" / "rel42") + controller.config_file = str(tmp_path / "watcher-raft" / "rel42" / "patroni-raft.yaml") + controller.service_name = "watcher-raft-rel42" + controller.service_file = str(tmp_path / "watcher-raft-rel42.service") + return controller + + +def test_configure(tmp_path: Path, controller: RaftController): + with open("templates/watcher.yml.j2") as file: + contents = file.read() + template = Template(contents) + + expected_content = template.render( + self_addr="10.0.0.1", + self_port=2222, + partner_addrs=["10.0.0.2"], + password="secret", + data_dir=f"{tmp_path}/watcher-raft/rel42", + ) + with ( + patch("raft_controller.render_file") as _render_file, + patch("raft_controller.create_directory") as _create_directory, + ): + assert controller.configure(2222, "10.0.0.1", ["10.0.0.2"], "secret") + + assert _create_directory.call_count == 2 + _create_directory.assert_any_call(f"{tmp_path}/watcher-raft/rel42", 0o700) + _create_directory.assert_any_call(f"{tmp_path}/watcher-raft/rel42/raft", 0o700) + _render_file.assert_called_once_with( + f"{tmp_path}/watcher-raft/rel42/patroni-raft.yaml", expected_content, 0o600 + ) + + +def test_remove_service_disables_unit_and_deletes_dir(tmp_path: Path, controller: RaftController): + Path(controller.service_file).write_text("[Unit]\nDescription=test\n") + + with ( + patch("raft_controller.service_running") as _service_running, + patch("raft_controller.service_stop") as _service_stop, + patch("raft_controller.service_disable") as _service_disable, + patch("raft_controller.rmtree") as _rmtree, + ): + assert controller.remove_service() + _service_running.assert_called_once_with(controller.service_name) + _service_stop.assert_called_once_with(controller.service_name) + _service_disable.assert_called_once_with(controller.service_name) + _rmtree.assert_called_once_with(controller.data_dir) + + +def test_install_service_returns_false_when_daemon_reload_fails( + tmp_path: Path, controller: RaftController +): + with ( + patch("raft_controller.daemon_reload") as _daemon_reload, + patch("raft_controller.render_file"), + patch("raft_controller.create_directory"), + ): + _daemon_reload.side_effect = SystemdError + + assert not install_service() + + +def test_install_service_uses_patroni_profile_execstart( + tmp_path: Path, controller: RaftController +): + with open("templates/watcher.service.j2") as file: + contents = file.read() + template = Template(contents) + + expected_content = template.render( + config_file="/var/snap/charmed-postgresql/common/watcher-raft" + ) + + with ( + patch("raft_controller.daemon_reload") as _daemon_reload, + patch("raft_controller.render_file") as _render_file, + patch("raft_controller.create_directory"), + ): + assert install_service() + + _render_file.assert_called_once_with(SERVICE_FILE, expected_content, 0o644, change_owner=False) + _daemon_reload.assert_called_once_with()
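
For context, a minimal sketch of how the pieces in this diff are expected to compose on a watcher unit. The helper name `reconfigure_watcher_raft` and the source of the address, partner, password, and CA values are illustrative assumptions, not part of this diff; only `install_service()`, `RaftController`, and `RAFT_PORT` come from the code above.

# Illustrative wiring only: reconfigure_watcher_raft and its arguments are
# hypothetical; install_service(), RaftController, and RAFT_PORT are from this diff.
from constants import RAFT_PORT
from raft_controller import RaftController, install_service


def reconfigure_watcher_raft(charm, self_addr, partner_addrs, password, ca_bundle):
    # Render /etc/systemd/system/watcher-raft@.service and reload systemd.
    if not install_service():
        return False
    # One RaftController per instance_id; data dir, config path, and the
    # systemd unit name (watcher-raft@<instance_id>) are derived from it.
    controller = RaftController(charm, instance_id="default")
    # Writes the Patroni-compatible Raft YAML (and CA bundle, if given);
    # refuses to proceed when self_addr or password is missing.
    if not controller.configure(
        RAFT_PORT, self_addr, partner_addrs, password=password, cas=ca_bundle
    ):
        return False
    # Enables and starts watcher-raft@default under systemd.
    return controller.start()

Status checks would then go through controller.get_status(RAFT_PORT, password), which reports quorum and membership via pysyncobj's TcpUtility.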