diff --git a/src/events/base_events.py b/src/events/base_events.py index 8020ba4..e2e1df0 100644 --- a/src/events/base_events.py +++ b/src/events/base_events.py @@ -33,7 +33,7 @@ Substrate, TLSState, ) -from statuses import CharmStatuses, ClusterStatuses, ScaleDownStatuses, StartStatuses +from statuses import CharmStatuses, ClusterStatuses, ScaleDownStatuses if TYPE_CHECKING: from charm import ValkeyCharm @@ -197,11 +197,8 @@ def _on_start(self, event: ops.StartEvent) -> None: event.defer() return - self.charm.status.set_running_status( - StartStatuses.SERVICE_STARTING.value, - scope="unit", - statuses_state=self.charm.state.statuses, - component_name=self.charm.cluster_manager.name, + self.charm.state.unit_server.update( + {"start_state": StartState.STARTING_WAITING_VALKEY.value} ) self.unit_fully_started.emit( is_primary=primary_endpoint diff --git a/src/managers/cluster.py b/src/managers/cluster.py index 6cded26..ec34c56 100644 --- a/src/managers/cluster.py +++ b/src/managers/cluster.py @@ -128,7 +128,11 @@ def reload_tls_settings(self, tls_config: dict[str, str]) -> None: def get_statuses(self, scope: Scope, recompute: bool = False) -> list[StatusObject]: """Compute the cluster manager's statuses.""" - status_list: list[StatusObject] = [] + status_list: list[StatusObject] = self.state.statuses.get( + scope=scope, + component=self.name, + running_status_only=True, + ).root # Peer relation not established yet, or model not built yet for unit or app if not self.state.cluster.model or not self.state.unit_server.model: diff --git a/src/managers/tls.py b/src/managers/tls.py index 59b0512..ca842e3 100644 --- a/src/managers/tls.py +++ b/src/managers/tls.py @@ -438,14 +438,14 @@ def get_statuses(self, scope: Scope, recompute: bool = False) -> list[StatusObje ): status_list.append(TLSStatuses.DISABLING_CLIENT_TLS_FAILED.value) - if self.state.cluster.tls_client_private_key and not self.state.client_tls_relation: - status_list.append(TLSStatuses.PRIVATE_KEY_BUT_NO_TLS.value) - if ( private_key_id := self.state.config.get(TLS_CLIENT_PRIVATE_KEY_CONFIG) ) and self.read_and_validate_private_key(str(private_key_id)) is None: status_list.append(TLSStatuses.PRIVATE_KEY_INVALID.value) + if self.state.cluster.tls_client_private_key and not self.state.client_tls_relation: + status_list.append(TLSStatuses.PRIVATE_KEY_BUT_NO_TLS.value) + if self.state.unit_server.tls_client_state == TLSState.TO_NO_TLS: status_list.append(TLSStatuses.DISABLING_CLIENT_TLS.value) diff --git a/src/statuses.py b/src/statuses.py index b080542..1a1205b 100644 --- a/src/statuses.py +++ b/src/statuses.py @@ -63,7 +63,6 @@ class StartStatuses(Enum): SERVICE_STARTING = StatusObject( status="maintenance", message="Waiting for Valkey to start...", - running="async", ) WAITING_FOR_SENTINEL_DISCOVERY = StatusObject( status="maintenance", diff --git a/src/workload_vm.py b/src/workload_vm.py index 0b24903..8673cab 100644 --- a/src/workload_vm.py +++ b/src/workload_vm.py @@ -93,6 +93,9 @@ def install(self, revision: str | None = None, retry_and_raise: bool = True) -> revision = str(SNAP_REVISIONS[platform.machine()]) try: + # TODO revesit this logic after snapd update is released + # refresh snapd to use candidate to bypass risv check issue. + snap.add("snapd", channel="candidate") # as long as 26.04 is not stable, we need to install the core26 snap from beta snap.add("core26", channel="beta") diff --git a/tests/integration/clients/requirer-charm/charmcraft.yaml b/tests/integration/clients/requirer-charm/charmcraft.yaml index f76d389..430fc21 100644 --- a/tests/integration/clients/requirer-charm/charmcraft.yaml +++ b/tests/integration/clients/requirer-charm/charmcraft.yaml @@ -103,20 +103,109 @@ actions: description: The username to use type: string + execute: + description: Execute an arbitrary Valkey command through the Glide client + params: + command: + description: The Valkey command to execute (e.g. "PING", "SET key value", "GET key") + type: string + config: + description: > + Serialized GlideClientConfiguration JSON produced by + glide_helpers.serialize_glide_config(). The charm connects using this + configuration directly, independent of any relation or glide-config + option. + type: string + get-credentials: description: Action for fetching all available credentials from relations. + start-continuous-writes: + description: > + Start a background daemon that continuously writes incrementing integers + to a Valkey list using the relation-provided credentials. The daemon + survives between action calls and can be stopped with + stop-continuous-writes. + params: + sleep-interval: + description: Seconds to sleep between writes (float, default 1.0) + type: number + default: 1.0 + clear-existing: + description: Delete any existing list values before starting (default true) + type: boolean + default: true + + get-continuous-writes-state: + description: > + Return the last written value and total count from the continuous-writes + state file without stopping the daemon. + + assert-continuous-writes-increasing: + description: > + Assert that the continuous-writes daemon is actively writing by sampling + the state file twice with a configurable wait between samples and + verifying the count has increased. + params: + wait: + description: Seconds to wait between the two state samples (default 10) + type: number + default: 10 + + clear-continuous-writes: + description: > + Delete the continuous-writes key from Valkey. Can be run while the daemon + is stopped to reset data between test runs. + + stop-continuous-writes: + description: > + Stop the continuous-writes daemon and return the last written value and + total count of successful writes. Use this after a disruptive operation to + retrieve stats for consistency verification. + params: + clear: + description: Delete continuous-writes data from Valkey after stopping (default + false) + type: boolean + default: false + + seed-data: + description: > + Seed Valkey with random 1 KB values using the relation-provided + credentials. Keys are written in batches of 5000 using MSET and named + "". + params: + target-gb: + description: Target amount of data to seed in GB (default 1.0) + type: number + default: 1.0 + key-prefix: + description: Prefix for generated keys (default "seed:key:") + type: string + default: "seed:key:" + config: options: data-interfaces-version: description: Version of data interfaces to use type: int default: 1 + glide-config: + description: > + JSON string with Glide connection options. When set, the charm uses + config-based connection instead of the Valkey relation. Expected keys: + endpoints (comma-separated "host:port" string), username (string), + password (string), tls_enabled (bool), cacert (base64-encoded PEM CA + certificate string), cert (base64-encoded PEM client certificate + string), key (base64-encoded PEM client private key string). + type: string + default: "" use-mtls: description: Flag to enable use of mutual TLS type: boolean default: false use-certificate-auth: - description: Flag to enable authentication via the common name of the client certificate + description: Flag to enable authentication via the common name of the client + certificate type: boolean default: false diff --git a/tests/integration/clients/requirer-charm/src/charm.py b/tests/integration/clients/requirer-charm/src/charm.py index d932177..18c80f9 100755 --- a/tests/integration/clients/requirer-charm/src/charm.py +++ b/tests/integration/clients/requirer-charm/src/charm.py @@ -5,8 +5,16 @@ """Charm the application.""" import asyncio +import base64 +import json import logging +import os +import signal import socket +import subprocess +import sys +import time +from pathlib import Path import ops from charmlibs.interfaces.tls_certificates import ( @@ -15,6 +23,9 @@ ) from charms.data_platform_libs.v0.data_interfaces import DatabaseCreatedEvent, DatabaseRequires from client import ValkeyClient +from continuous_writes import DaemonConfig, TlsConfig +from continuous_writes import clear_key as cw_clear +from cw_helpers import CWPath, cw_llen, wait_for_pid_exit from dpcharmlibs.interfaces import ( DataContractV1, RequirerCommonModel, @@ -25,6 +36,9 @@ ValkeyResponseModel, build_model, ) +from glide import GlideClient +from glide_helpers import deserialize_glide_config, parse_custom_command_result +from pydantic import BaseModel logger = logging.getLogger(__name__) @@ -33,6 +47,25 @@ class RefreshTLSCertificatesEvent(ops.EventBase): """Event for refreshing peer TLS certificates.""" +class GlideConfig(BaseModel): + """Represents the glide-config charm configuration option.""" + + endpoints: str + username: str + password: str + tls_enabled: bool = False + cacert: str = "" + cert: str = "" + key: str = "" + + @classmethod + def from_json(cls, raw: str) -> "GlideConfig": + return cls.model_validate_json(raw) + + def to_json(self) -> str: + return self.model_dump_json() + + class RequirerCharm(ops.CharmBase): """Charm that acts as client for Valkey.""" @@ -89,7 +122,26 @@ def __init__(self, framework: ops.Framework): framework.observe(self.on.config_changed, self._on_config_changed) framework.observe(self.on.set_action, self._on_set_action) framework.observe(self.on.get_action, self._on_get_action) + framework.observe(self.on.execute_action, self._on_execute_action) framework.observe(self.on.get_credentials_action, self._on_get_credentials_action) + framework.observe(self.on.seed_data_action, self._on_seed_data_action) + framework.observe( + self.on.start_continuous_writes_action, self._on_start_continuous_writes_action + ) + framework.observe( + self.on.stop_continuous_writes_action, self._on_stop_continuous_writes_action + ) + framework.observe( + self.on.clear_continuous_writes_action, self._on_clear_continuous_writes_action + ) + framework.observe( + self.on.get_continuous_writes_state_action, + self._on_get_continuous_writes_state_action, + ) + framework.observe( + self.on.assert_continuous_writes_increasing_action, + self._on_assert_continuous_writes_increasing_action, + ) framework.observe(self.valkey_interface.on.endpoints_changed, self._on_endpoints_changed) @property @@ -114,8 +166,23 @@ def remote_responses(self) -> list[ResourceProviderModel] | None: ).requests @property - def credentials(self) -> dict[str | None, str | None]: - """Retrieve the client credentials provided by Valkey.""" + def _glide_config(self) -> GlideConfig | None: + """Parse the glide-config JSON option, or None if not set.""" + if not (raw := str(self.config.get("glide-config", "")).strip()): + return None + return GlideConfig.from_json(raw) + + @property + def _use_config(self) -> bool: + """Return True when glide-config is set.""" + return self._glide_config is not None + + @property + def credentials(self) -> dict[str, str | None]: + """Retrieve the client credentials from config or relation.""" + if cfg := self._glide_config: + return {cfg.username: cfg.password or None} + if self.data_interfaces_version == 0: if not self.valkey_relation: return {"": None} @@ -137,7 +204,10 @@ def credentials(self) -> dict[str | None, str | None]: @property def primary_endpoint(self) -> str | None: - """Retrieve the write-endpoints provided by Valkey.""" + """Retrieve the write-endpoints from config or relation.""" + if cfg := self._glide_config: + return cfg.endpoints or None + if self.data_interfaces_version == 0: if not self.valkey_relation: return None @@ -151,7 +221,10 @@ def primary_endpoint(self) -> str | None: @property def tls_enabled(self) -> bool: - """Retrieve the tls flag provided by Valkey.""" + """Retrieve the TLS flag from config or relation.""" + if cfg := self._glide_config: + return cfg.tls_enabled + if not self.valkey_relation: return False @@ -177,7 +250,10 @@ def use_mtls(self) -> bool: @property def tls_ca_cert(self) -> str | None: - """Retrieve the tls CA cert provided by Valkey.""" + """Retrieve the TLS CA cert from config or relation.""" + if cfg := self._glide_config: + return base64.b64decode(cfg.cacert).decode() if cfg.cacert else None + if self.data_interfaces_version == 0: if not self.valkey_relation: return None @@ -191,6 +267,10 @@ def tls_ca_cert(self) -> str | None: @property def certificate(self) -> str | None: + """Retrieve the client certificate from config or the certificates relation.""" + if cfg := self._glide_config: + return base64.b64decode(cfg.cert).decode() if cfg.cert else None + certificates, _ = self.certificates.get_assigned_certificates() if not certificates: return None @@ -199,6 +279,10 @@ def certificate(self) -> str | None: @property def private_key(self) -> str | None: + """Retrieve the client private key from config or the certificates relation.""" + if cfg := self._glide_config: + return base64.b64decode(cfg.key).decode() if cfg.key else None + _, private_key = self.certificates.get_assigned_certificates() if not private_key: return None @@ -207,11 +291,18 @@ def private_key(self) -> str | None: def get_valkey_client(self, user: str) -> ValkeyClient: """Get a valkey client.""" + if not self.primary_endpoint: + raise ValueError("No endpoint available.") + if not self.credentials: + raise ValueError("No credentials available.") + if self.tls_enabled and ( + not self.certificate or not self.private_key or not self.tls_ca_cert + ): + raise ValueError("TLS is enabled but certificates are not yet available.") return ValkeyClient( username="" if self.config.get("use-certificate-auth") else user, password="" if self.config.get("use-certificate-auth") else self.credentials.get(user), - host=self.primary_endpoint.split(":")[0], - port=int(self.primary_endpoint.split(":")[1]), + endpoints=self.primary_endpoint.split(","), tls_cert=self.certificate.encode() if self.use_mtls else None, tls_key=self.private_key.encode() if self.use_mtls else None, tls_ca_cert=self.tls_ca_cert.encode() if self.tls_enabled else None, @@ -221,10 +312,6 @@ def _on_start(self, event: ops.StartEvent) -> None: """Handle start event.""" self.unit.status = ops.ActiveStatus() - def _on_config_changed(self, event: ops.ConfigChangedEvent) -> None: - """Handle config changes.""" - self.refresh_tls_certificates_event.emit() - def _on_set_action(self, event: ops.ActionEvent) -> None: """Handle set action.""" if not self.valkey_relation: @@ -275,6 +362,37 @@ def _on_get_action(self, event: ops.ActionEvent) -> None: event.fail(f"Failed to read data: {e}") logger.error("Failed to read data: %s", e) + def _on_execute_action(self, event: ops.ActionEvent) -> None: + """Handle execute action.""" + if not (command := str(event.params.get("command", ""))): + event.fail("Parameter command is required.") + event.set_results({"ok": False}) + return + + args = command.split() + + try: + glide_config = deserialize_glide_config(str(event.params["config"])) + except Exception as e: + event.fail(f"Failed to deserialize config: {e}") + event.set_results({"ok": False}) + return + + async def _run(): + client = await GlideClient.create(glide_config) + try: + return await client.custom_command(args) + finally: + await client.close() + + try: + result = asyncio.run(_run()) + event.set_results( + {"ok": True, "result": json.dumps(parse_custom_command_result(result))} + ) + except Exception as e: + event.set_results({"ok": False, "result": json.dumps(str(e))}) + def _on_get_credentials_action(self, event: ops.ActionEvent) -> None: """Return the credentials an action response.""" if not self.valkey_relation: @@ -291,6 +409,238 @@ def _on_get_credentials_action(self, event: ops.ActionEvent) -> None: } ) + def _on_seed_data_action(self, event: ops.ActionEvent) -> None: + """Handle seed-data action.""" + if not self._use_config and not self.valkey_relation: + event.fail( + "The action can be run only after a relation is created or glide-config is set." + ) + event.set_results({"ok": False}) + return + + target_gb = float(event.params.get("target-gb", 1.0)) + key_prefix = str(event.params.get("key-prefix", "seed:key:")) + + user, _ = next(iter(self.credentials.items())) + client = self.get_valkey_client(user) + try: + keys_added = asyncio.run(client.seed_data(target_gb=target_gb, key_prefix=key_prefix)) + event.set_results({"ok": True, "keys-added": keys_added}) + except Exception as e: + event.fail(f"Failed to seed data: {e}") + logger.error("Failed to seed data: %s", e) + + def _on_start_continuous_writes_action(self, event: ops.ActionEvent) -> None: + """Handle start-continuous-writes action.""" + if not self._use_config and not self.valkey_relation: + event.fail( + "The action can be run only after a relation is created or glide-config is set." + ) + return + + if not self.primary_endpoint: + event.fail("No primary endpoint available.") + return + + if not self.credentials: + event.fail("No credentials available.") + return + + if self.tls_enabled: + if not self.certificate or not self.private_key or not self.tls_ca_cert: + event.fail("TLS is enabled but certificates are not yet available.") + return + + sleep_interval = float(event.params.get("sleep-interval", 1.0)) + clear_existing = bool(event.params.get("clear-existing", True)) + + # Fail if a daemon is already running + if CWPath.PID.value.exists(): + try: + pid = int(CWPath.PID.value.read_text().strip()) + os.kill(pid, 0) # check existence without signalling + event.fail(f"Continuous-writes daemon is already running with PID {pid}.") + return + except ProcessLookupError: + # Stale PID file — clean up and proceed + CWPath.PID.value.unlink(missing_ok=True) + except ValueError: + CWPath.PID.value.unlink(missing_ok=True) + + # Clear previous state so the new run starts fresh + CWPath.STATE.value.unlink(missing_ok=True) + + # Resolve the first available credential from the relation + username, password = next(iter(self.credentials.items())) + + tls_config = None + if self.tls_enabled: + CWPath.CERT.value.write_bytes(self.certificate.encode()) + CWPath.KEY.value.write_bytes(self.private_key.encode()) + CWPath.CA.value.write_bytes(self.tls_ca_cert.encode()) + tls_config = TlsConfig( + cert_path=str(CWPath.CERT.value), + key_path=str(CWPath.KEY.value), + ca_path=str(CWPath.CA.value), + ) + + DaemonConfig( + endpoints=self.primary_endpoint, + username=username, + password=password, + tls=tls_config, + initial_count=0, + clear_existing=clear_existing, + ).to_file(CWPath.CONFIG.value) + + daemon_script = Path(__file__).parent / "continuous_writes.py" + log_file = CWPath.LOG.value.open("w") + proc = subprocess.Popen( + [sys.executable, str(daemon_script), str(CWPath.CONFIG.value), str(sleep_interval)], + stdout=log_file, + stderr=log_file, + start_new_session=True, + ) + log_file.close() + logger.info( + "Started continuous-writes daemon with PID %d (log: %s)", + proc.pid, + CWPath.LOG.value, + ) + event.set_results({"ok": True, "pid": proc.pid}) + + def _on_stop_continuous_writes_action(self, event: ops.ActionEvent) -> None: + """Handle stop-continuous-writes action.""" + if not CWPath.PID.value.exists(): + event.fail("No continuous-writes daemon is running (PID file not found).") + return + + try: + pid = int(CWPath.PID.value.read_text().strip()) + os.kill(pid, signal.SIGTERM) + except ProcessLookupError: + logger.warning("Daemon PID %s was not running; reading last state.", pid) + except ValueError: + event.fail("PID file contained invalid data.") + return + except OSError as exc: + event.fail(f"Failed to signal daemon: {exc}") + return + + # Wait for the daemon to exit and flush its final state, with retries + if not wait_for_pid_exit(pid): + logger.warning( + "Daemon PID %d had to be force-killed; state file may be incomplete.", pid + ) + + if not CWPath.STATE.value.exists(): + event.fail("State file not found — the daemon may not have written anything.") + return + + try: + state = json.loads(CWPath.STATE.value.read_text()) + except (json.JSONDecodeError, OSError) as exc: + event.fail(f"Failed to read state file: {exc}") + return + + logger.info( + "Stopped continuous-writes daemon. last_written=%d, count=%d", + state["last_written"], + state["count"], + ) + + if bool(event.params.get("clear", False)): + try: + daemon_config = DaemonConfig.from_file(CWPath.CONFIG.value) + asyncio.run(cw_clear(daemon_config)) + except Exception as exc: + logger.warning("Failed to clear continuous-writes data: %s", exc) + + event.set_results( + { + "ok": True, + "last-written-value": state["last_written"], + "count": state["count"], + } + ) + + def _on_clear_continuous_writes_action(self, event: ops.ActionEvent) -> None: + """Handle clear-continuous-writes action.""" + if not CWPath.CONFIG.value.exists(): + event.fail("No continuous-writes config found — run start-continuous-writes first.") + return + + try: + daemon_config = DaemonConfig.from_file(CWPath.CONFIG.value) + asyncio.run(cw_clear(daemon_config)) + except Exception as exc: + event.fail(f"Failed to clear continuous-writes data: {exc}") + return + + event.set_results({"ok": True}) + + def _on_get_continuous_writes_state_action(self, event: ops.ActionEvent) -> None: + """Handle get-continuous-writes-state action.""" + if not CWPath.STATE.value.exists(): + event.fail("State file not found — the daemon may not have written anything yet.") + return + + try: + state = json.loads(CWPath.STATE.value.read_text()) + except (json.JSONDecodeError, OSError) as exc: + event.fail(f"Failed to read state file: {exc}") + return + + event.set_results( + { + "ok": True, + "last-written-value": state["last_written"], + "count": state["count"], + } + ) + + def _on_assert_continuous_writes_increasing_action(self, event: ops.ActionEvent) -> None: + """Handle assert-continuous-writes-increasing action.""" + if not CWPath.CONFIG.value.exists(): + event.fail("No continuous-writes config found — run start-continuous-writes first.") + return + + try: + config = DaemonConfig.from_file(CWPath.CONFIG.value) + except Exception as exc: + event.fail(f"Failed to load continuous-writes config: {exc}") + return + + try: + count_before = asyncio.run(cw_llen(config)) + except Exception as exc: + event.fail(f"Failed to read list length from Valkey: {exc}") + return + + wait = float(event.params.get("wait", 10.0)) + time.sleep(wait) + + try: + count_after = asyncio.run(cw_llen(config)) + except Exception as exc: + event.fail(f"Failed to read list length from Valkey after wait: {exc}") + return + + if count_after <= count_before: + event.fail( + f"Writes are not increasing: list length was {count_before} before and" + f" {count_after} after {wait}s." + ) + return + + event.set_results( + { + "ok": True, + "count-before": count_before, + "count-after": count_after, + } + ) + def _on_resource_created(self, event: ResourceCreatedEvent[ResourceProviderModel]) -> None: """Handle resource created event.""" logger.info("Resource created") @@ -305,6 +655,55 @@ def _on_database_created(self, event: DatabaseCreatedEvent) -> None: """Handle the event triggered by data-interfaces v0.""" logger.info("Database created") + def _on_config_changed(self, event: ops.ConfigChangedEvent) -> None: + """Hot-reload the continuous-writes daemon when glide-config changes.""" + self.refresh_tls_certificates_event.emit() + + if not self._use_config or not CWPath.PID.value.exists(): + return + + try: + current_config = DaemonConfig.from_file(CWPath.CONFIG.value) + except Exception as exc: + logger.warning("Failed to read current daemon config: %s", exc) + return + + if current_config.endpoints == self.primary_endpoint: + return + + logger.info( + "Endpoints changed from %s to %s; reloading continuous-writes daemon.", + current_config.endpoints, + self.primary_endpoint, + ) + + username, password = next(iter(self.credentials.items())) + tls_config = current_config.tls + if self.tls_enabled and self.certificate and self.private_key and self.tls_ca_cert: + CWPath.CERT.value.write_bytes(self.certificate.encode()) + CWPath.KEY.value.write_bytes(self.private_key.encode()) + CWPath.CA.value.write_bytes(self.tls_ca_cert.encode()) + tls_config = TlsConfig( + cert_path=str(CWPath.CERT.value), + key_path=str(CWPath.KEY.value), + ca_path=str(CWPath.CA.value), + ) + + DaemonConfig( + endpoints=self.primary_endpoint, + username=username, + password=password, + tls=tls_config, + initial_count=0, + ).to_file(CWPath.CONFIG.value) + + try: + pid = int(CWPath.PID.value.read_text().strip()) + os.kill(pid, signal.SIGUSR1) + logger.info("Sent SIGUSR1 to continuous-writes daemon PID %d.", pid) + except (ProcessLookupError, ValueError, OSError) as exc: + logger.warning("Failed to send SIGUSR1 to daemon: %s", exc) + if __name__ == "__main__": # pragma: nocover ops.main(RequirerCharm) diff --git a/tests/integration/clients/requirer-charm/src/client.py b/tests/integration/clients/requirer-charm/src/client.py index f5e051b..17134f8 100644 --- a/tests/integration/clients/requirer-charm/src/client.py +++ b/tests/integration/clients/requirer-charm/src/client.py @@ -3,7 +3,9 @@ """ValkeyClient utility class to connect to valkey servers.""" +import json import logging +import os from glide import ( AdvancedGlideClientConfiguration, @@ -23,15 +25,13 @@ class ValkeyClient: def __init__( self, username: str, - password: str, - host: str, - port: int, + password: str | None, + endpoints: list[str], tls_cert: bytes | None, tls_key: bytes | None, tls_ca_cert: bytes | None, ): - self.host = host - self.port = port + self.endpoints = endpoints self.user = username self.password = password self.tls_cert = tls_cert @@ -40,7 +40,11 @@ def __init__( async def create_client(self) -> GlideClient: """Initialize the Valkey client.""" - credentials = ServerCredentials(username=self.user, password=self.password) + addresses = [ + NodeAddress(host, int(port_str)) + for endpoint in self.endpoints + for host, port_str in [endpoint.rsplit(":", 1)] + ] tls_config = TlsAdvancedConfiguration( client_cert_pem=self.tls_cert if self.tls_cert else None, @@ -49,9 +53,9 @@ async def create_client(self) -> GlideClient: ) client_config = GlideClientConfiguration( - [NodeAddress(host=self.host, port=self.port)], + addresses, use_tls=True if self.tls_ca_cert else False, - credentials=credentials, + credentials=ServerCredentials(username=self.user, password=self.password), request_timeout=1000, # in milliseconds advanced_config=AdvancedGlideClientConfiguration(tls_config=tls_config), ) @@ -74,6 +78,55 @@ async def get_key(self, key: str) -> str: try: value = await client.get(key) - return value.decode() + return value.decode() if value else "" # Return empty string if key does not exist + finally: + await client.close() + + async def seed_data(self, target_gb: float = 1.0, key_prefix: str = "seed:key:") -> int: + """Seed Valkey with random data and return the number of keys written.""" + value_size_bytes = 1024 + batch_size = 5000 + total_keys = int(target_gb * 1024 * 1024 * 1024) // value_size_bytes + + random_data = os.urandom(value_size_bytes).hex()[:value_size_bytes] + keys_added = 0 + + client = await self.create_client() + try: + while keys_added < total_keys: + batch_end = min(keys_added + batch_size, total_keys) + data = {f"{key_prefix}{i}": random_data for i in range(keys_added, batch_end)} + result = await client.mset(data) + if result != "OK": + raise RuntimeError(f"mset failed: {result}") + keys_added = batch_end + logger.info("Seeding progress: %d/%d keys", keys_added, total_keys) + finally: + await client.close() + + return keys_added + + async def execute_command(self, args: list[str]) -> str: + """Execute an arbitrary Valkey command and return the result as a string.""" + client = await self.create_client() + + try: + result = await client.custom_command(args) + str_result = "" + if result is None: + str_result = "" + elif isinstance(result, bytes): + str_result = result.decode() + elif isinstance(result, list): + # Decode bytes in lists (e.g. from LRANGE) to return a JSON-serializable structure + str_result = [ + item.decode() if isinstance(item, bytes) else item for item in result + ] + else: + str_result = str(result) # Fallback to string conversion for other types + + return json.dumps( + str_result + ) # For other result types, return a JSON string representation finally: await client.close() diff --git a/tests/integration/clients/requirer-charm/src/continuous_writes.py b/tests/integration/clients/requirer-charm/src/continuous_writes.py new file mode 100644 index 0000000..ae0d53e --- /dev/null +++ b/tests/integration/clients/requirer-charm/src/continuous_writes.py @@ -0,0 +1,280 @@ +#!/usr/bin/env python3 +# Copyright 2026 Canonical Ltd. +# See LICENSE file for licensing details. + +"""Continuous writes daemon for Valkey integration testing. + +Spawned by the requirer charm's start-continuous-writes action. Reads +connection config from a JSON file, writes incrementing integers to a +Valkey list, and tracks the last successfully written value atomically. + +Usage: + python3 continuous_writes.py [sleep_interval] + +The config JSON must contain: + endpoints - comma-separated "host:port,host:port,..." string + username - Valkey username + password - Valkey password + tls_enabled - bool (optional, default false) + cert_path - path to client cert PEM (required if tls_enabled) + key_path - path to client key PEM (required if tls_enabled) + ca_path - path to CA cert PEM (required if tls_enabled) + initial_count - int to start counter from (optional, default 0) + +On write failure the same counter value is retried until it succeeds before +advancing, so no gaps are introduced in the sequence. + +State is written atomically to STATE_PATH after each successful write: + {"last_written": N, "count": N} + +PID is written to PID_PATH on startup and removed on exit. +""" + +import asyncio +import json +import logging +import os +import signal +import sys +from pathlib import Path + +from glide import ( + AdvancedGlideClientConfiguration, + BackoffStrategy, + GlideClient, + GlideClientConfiguration, + NodeAddress, + ServerCredentials, + TlsAdvancedConfiguration, +) +from pydantic import BaseModel + +KEY = "cw_key" +CONFIG_PATH = Path("/tmp/cw_config.json") +STATE_PATH = Path("/tmp/cw_state.json") +PID_PATH = Path("/tmp/cw_daemon.pid") + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(message)s", + stream=sys.stderr, +) +logger = logging.getLogger(__name__) + + +class TlsConfig(BaseModel): + """TLS certificate paths for the Glide client.""" + + cert_path: str + key_path: str + ca_path: str + + +class DaemonConfig(BaseModel): + """Connection configuration for the continuous-writes daemon.""" + + endpoints: str + username: str + password: str + tls: TlsConfig | None = None + initial_count: int = 0 + clear_existing: bool = False + + @classmethod + def from_file(cls, path: Path) -> "DaemonConfig": + """Load and validate config from a JSON file.""" + return cls.model_validate_json(path.read_text()) + + def to_file(self, path: Path) -> None: + """Serialise config to a JSON file.""" + path.write_text(self.model_dump_json()) + + +def _write_state_atomic(last_written: int, count: int) -> None: + """Write state file atomically using a temp-file + rename.""" + data = json.dumps({"last_written": last_written, "count": count}) + tmp = STATE_PATH.with_suffix(".tmp") + tmp.write_text(data) + tmp.rename(STATE_PATH) + + +async def _make_client(config: DaemonConfig) -> GlideClient: + addresses = [ + NodeAddress(host, int(port_str)) + for endpoint in config.endpoints.split(",") + for host, port_str in [endpoint.rsplit(":", 1)] + ] + + tls_cert = tls_key = tls_ca = None + if config.tls is not None: + tls_cert = Path(config.tls.cert_path).read_bytes() + tls_key = Path(config.tls.key_path).read_bytes() + tls_ca = Path(config.tls.ca_path).read_bytes() + + glide_config = GlideClientConfiguration( + addresses=addresses, + credentials=ServerCredentials( + username=config.username, + password=config.password, + ), + use_tls=config.tls is not None, + request_timeout=1000, + reconnect_strategy=BackoffStrategy(num_of_retries=1, factor=0, exponent_base=1), + advanced_config=AdvancedGlideClientConfiguration( + tls_config=TlsAdvancedConfiguration( + client_cert_pem=tls_cert, + client_key_pem=tls_key, + root_pem_cacerts=tls_ca, + use_insecure_tls=True if config.tls is not None else None, + ) + ), + ) + return await GlideClient.create(glide_config) + + +async def clear(client: GlideClient) -> None: + """Delete the continuous-writes list key from Valkey.""" + await client.delete([KEY]) + logger.info("Cleared existing values for key '%s'.", KEY) + + +async def _initial_count(config: DaemonConfig, client: GlideClient) -> tuple[int, int]: + """Return (counter, list_len) to start from, resuming from state file if present.""" + if config.clear_existing: + try: + await clear(client) + except Exception as exc: + logger.warning("Failed to clear existing values: %s", exc) + return config.initial_count, 0 + + counter = config.initial_count + if STATE_PATH.exists(): + try: + state = json.loads(STATE_PATH.read_text()) + counter = state.get("last_written", counter) + 1 + except (json.JSONDecodeError, KeyError): + pass + + count = 0 + try: + count = await client.llen(KEY) + except Exception: + pass + + return counter, count + + +def _try_reload(old: DaemonConfig) -> DaemonConfig: + """Re-read config from disk; log changes and return updated config or original on failure.""" + try: + new = DaemonConfig.from_file(CONFIG_PATH) + except Exception as exc: + logger.warning("Failed to reload config: %s", exc) + return old + + changes = [] + if old.endpoints != new.endpoints: + changes.append(f"endpoints: {old.endpoints!r} -> {new.endpoints!r}") + if old.username != new.username: + changes.append(f"username: {old.username!r} -> {new.username!r}") + if (old.tls is not None) != (new.tls is not None): + changes.append(f"tls_enabled: {old.tls is not None} -> {new.tls is not None}") + + if changes: + logger.info("Config reloaded — changes: %s", "; ".join(changes)) + else: + logger.info("Config reloaded — no changes detected.") + + return new + + +async def _close_client(client: GlideClient | None) -> None: + """Close client if not None, swallowing errors.""" + if client is not None: + try: + await client.close() + except Exception: + pass + + +async def clear_key(config: DaemonConfig) -> None: + """Connect to Valkey and delete the continuous-writes list key.""" + client = await _make_client(config) + try: + await clear(client) + finally: + await _close_client(client) + + +async def _write_one(client: GlideClient, counter: int) -> tuple[int, int]: + """Write one value, return (last_written, new_count).""" + new_len = await client.lpush(KEY, [str(counter)]) + if not new_len: + raise RuntimeError("LPUSH returned 0/None") + return counter, new_len + + +async def run(config: DaemonConfig, sleep_interval: float) -> None: + """Run the main write loop until SIGTERM/SIGINT.""" + stop = asyncio.Event() + reload = asyncio.Event() + + loop = asyncio.get_running_loop() + loop.add_signal_handler(signal.SIGTERM, stop.set) + loop.add_signal_handler(signal.SIGINT, stop.set) + loop.add_signal_handler(signal.SIGUSR1, reload.set) + + client: GlideClient = await _make_client(config) + counter, count = await _initial_count(config, client) + last_written = counter - 1 + logger.info( + "Starting continuous writes from counter=%d (existing list len=%d)", counter, count + ) + + try: + while not stop.is_set(): + try: + if reload.is_set(): + reload.clear() + config = _try_reload(config) + await _close_client(client) + client = None + if client is None: + client = await _make_client(config) + last_written, count = await _write_one(client, counter) + _write_state_atomic(last_written, count) + logger.info("Wrote %d (list len=%d)", counter, count) + counter += 1 + except Exception as exc: + logger.warning("Write failed for counter=%d, will retry: %s", counter, exc) + # In standalone mode, Glide locks onto the primary node during initialization and does not auto-refresh. + # If the primary fails, the client will time out indefinitely until manually recreated, making long-term client reuse highly unreliable. + try: + await _close_client(client) + except Exception: + pass + client = None + + try: + await asyncio.wait_for(stop.wait(), timeout=sleep_interval) + except asyncio.TimeoutError: + pass + finally: + await _close_client(client) + + # Flush final state before exiting + _write_state_atomic(last_written, count) + logger.info("Daemon exiting — last_written=%d, count=%d", last_written, count) + + +if __name__ == "__main__": + config_path = Path(sys.argv[1]) if len(sys.argv) > 1 else CONFIG_PATH + sleep_interval = float(sys.argv[2]) if len(sys.argv) > 2 else 1.0 + + config = DaemonConfig.from_file(config_path) + + PID_PATH.write_text(str(os.getpid())) + try: + asyncio.run(run(config, sleep_interval)) + finally: + PID_PATH.unlink(missing_ok=True) diff --git a/tests/integration/clients/requirer-charm/src/cw_helpers.py b/tests/integration/clients/requirer-charm/src/cw_helpers.py new file mode 100644 index 0000000..0593f06 --- /dev/null +++ b/tests/integration/clients/requirer-charm/src/cw_helpers.py @@ -0,0 +1,70 @@ +# Copyright 2026 Canonical Ltd. +# See LICENSE file for licensing details. + +"""Helpers for the continuous-writes daemon used by the requirer charm.""" + +import enum +import logging +import os +import signal +import time +from pathlib import Path + +from continuous_writes import KEY as CW_KEY +from continuous_writes import DaemonConfig +from continuous_writes import _make_client as _cw_make_client + +logger = logging.getLogger(__name__) + + +class CWPath(enum.Enum): + """Paths used by the continuous-writes daemon.""" + + CONFIG = Path("/tmp/cw_config.json") + STATE = Path("/tmp/cw_state.json") + PID = Path("/tmp/cw_daemon.pid") + LOG = Path("/tmp/cw_daemon.log") + CERT = Path("/tmp/cw_client.pem") + KEY = Path("/tmp/cw_client.key") + CA = Path("/tmp/cw_client_ca.pem") + + +def wait_for_pid_exit( + pid: int, poll_interval: int = 1, max_attempts: int = 10, force_kill: bool = True +) -> bool: + """Wait for a process to exit. + + Returns True if the process exited cleanly within max_attempts, False otherwise. + If force_kill is True and the process is still running after max_attempts, sends SIGKILL. + """ + for attempt in range(max_attempts): + time.sleep(poll_interval) + try: + os.kill(pid, 0) # signal 0 checks existence without sending a signal + except ProcessLookupError: + logger.info("Daemon PID %d exited after %d second(s).", pid, attempt * poll_interval) + return True + except OSError: + pass # EPERM — process exists but unowned; treat as still running + + logger.warning( + "Daemon PID %d did not exit after %d second(s).", + pid, + max_attempts * poll_interval, + ) + if force_kill: + logger.warning("Sending SIGKILL to daemon PID %d.", pid) + try: + os.kill(pid, signal.SIGKILL) + except OSError: + pass + return False + + +async def cw_llen(config: DaemonConfig) -> int: + """Return the current length of the continuous-writes list in Valkey.""" + client = await _cw_make_client(config) + try: + return await client.llen(CW_KEY) + finally: + await client.close() diff --git a/tests/integration/clients/requirer-charm/src/glide_helpers.py b/tests/integration/clients/requirer-charm/src/glide_helpers.py new file mode 100644 index 0000000..b872d6f --- /dev/null +++ b/tests/integration/clients/requirer-charm/src/glide_helpers.py @@ -0,0 +1,115 @@ +# Copyright 2026 Canonical Ltd. +# See LICENSE file for licensing details. + +"""Serialization/deserialization helpers for GlideClientConfiguration objects. + +Converts a GlideClientConfiguration (and its nested objects) to/from a JSON +string so it can be passed as a Juju action parameter. + +Bytes fields are base64-encoded; enums are stored by name; nested Glide +objects are tagged with ``__class__`` for round-trip reconstruction. +""" + +import base64 +import json +from enum import Enum +from typing import Any + +from glide import ( + AdvancedGlideClientConfiguration, + BackoffStrategy, + GlideClientConfiguration, + NodeAddress, + ReadFrom, + ServerCredentials, + TlsAdvancedConfiguration, +) + +# Maps each Glide class to the set of fields that should be serialized. +# Tuple values like ``(list, NodeAddress)`` are documentation only — the +# serialize/deserialize logic recurses structurally, not via this type info. +SCHEMA: dict[type, dict[str, Any]] = { + GlideClientConfiguration: { + "addresses": (list, NodeAddress), + "use_tls": bool, + "request_timeout": (int, type(None)), + "read_from": ReadFrom, + "credentials": (ServerCredentials, type(None)), + "reconnect_strategy": (BackoffStrategy, type(None)), + "advanced_config": (AdvancedGlideClientConfiguration, type(None)), + }, + NodeAddress: { + "host": str, + "port": int, + }, + ServerCredentials: { + "username": (str, type(None)), + "password": (str, type(None)), + }, + BackoffStrategy: { + "num_of_retries": int, + "factor": int, + "exponent_base": int, + "jitter_percent": (int, type(None)), + }, + AdvancedGlideClientConfiguration: { + "connection_timeout": (int, type(None)), + "tls_config": (TlsAdvancedConfiguration, type(None)), + }, + TlsAdvancedConfiguration: { + "use_insecure_tls": bool, + "client_cert_pem": (bytes, type(None)), + "client_key_pem": (bytes, type(None)), + "root_pem_cacerts": (bytes, type(None)), + }, +} + +_GLIDE_CLASSES: dict[str, type] = {cls.__name__: cls for cls in SCHEMA} +_ENUM_CLASSES: dict[str, type[Enum]] = {"ReadFrom": ReadFrom} + + +def deserialize(d: Any) -> Any: + """Recursively deserialize a JSON-compatible structure back to Glide objects.""" + if d is None or not isinstance(d, (dict, list)): + return d + if isinstance(d, list): + return [deserialize(i) for i in d] + if "__bytes__" in d: + return base64.b64decode(d["__bytes__"]) + if "__enum__" in d: + cls = _ENUM_CLASSES[d["__enum__"]] + return cls[d["value"]] + if "__class__" in d: + cls = _GLIDE_CLASSES[d["__class__"]] + fields = {k: deserialize(v) for k, v in d.items() if k != "__class__"} + return cls(**fields) + return d + + +def deserialize_glide_config(payload: str) -> GlideClientConfiguration: + """Deserialize a JSON string back to a GlideClientConfiguration.""" + return deserialize(json.loads(payload)) + + +def parse_custom_command_result(result: Any) -> Any: + """Recursively convert a custom_command return value to a JSON-serializable form. + + Glide's custom_command can return bytes, lists (possibly nested), mappings, + integers, booleans, or None. bytes values are decoded as UTF-8 with a + fallback to base64 so the result is always a plain str. + """ + if result is None: + return None + if isinstance(result, bytes): + try: + return result.decode("utf-8") + except UnicodeDecodeError: + return base64.b64encode(result).decode("ascii") + if isinstance(result, list): + return [parse_custom_command_result(item) for item in result] + if isinstance(result, dict): + return { + parse_custom_command_result(k): parse_custom_command_result(v) + for k, v in result.items() + } + return result # int, float, bool, str diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index d6366ea..d9ee2a8 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -8,35 +8,33 @@ import pytest from literals import Substrate -from tests.integration.continuous_writes import ContinuousWrites -from tests.integration.helpers import APP_NAME +from tests.integration.helpers import GLIDE_RUNNER_NAME, are_apps_active_and_agents_idle logger = logging.getLogger(__name__) -@pytest.fixture(scope="function") -def c_writes(juju: jubilant.Juju): - """Create instance of the ContinuousWrites.""" - app = APP_NAME - logger.info("Creating ContinuousWrites instance for app with name %s", app) - return ContinuousWrites(juju, app) - - -@pytest.fixture(scope="function") -def c_writes_runner(juju: jubilant.Juju, c_writes: ContinuousWrites): - """Start continuous write operations and clears writes at the end of the test.""" - c_writes.start() - yield - logger.info("Clearing continuous writes after test completion") - logger.info(c_writes.clear()) +@pytest.fixture +def glide_runner_charm(arch: str) -> str: + """Path to the charm file to use for testing.""" + # Return str instead of pathlib.Path since python-libjuju's model.deploy(), juju deploy, and + # juju bundle files expect local charms to begin with `./` or `/` to distinguish them from + # Charmhub charms. + return f"./tests/integration/clients/requirer-charm/requirer-charm_ubuntu@24.04-{arch}.charm" @pytest.fixture(scope="function") -async def c_writes_async_clean(c_writes: ContinuousWrites): - """Clear continuous write operations at the end of the test.""" - yield - logger.info("Clearing continuous writes after test completion") - logger.info(await c_writes.async_clear()) +def glide_runner(juju: jubilant.Juju, glide_runner_charm: str) -> None: + """Deploy continuous writes runner charm if not already deployed.""" + if GLIDE_RUNNER_NAME not in juju.status().apps: + juju.deploy(glide_runner_charm, app=GLIDE_RUNNER_NAME) + juju.wait( + lambda status: are_apps_active_and_agents_idle( + status, GLIDE_RUNNER_NAME, idle_period=30 + ), + timeout=600, + delay=5, + successes=3, + ) @pytest.fixture(scope="session") diff --git a/tests/integration/continuous_writes.py b/tests/integration/continuous_writes.py deleted file mode 100644 index 0a34337..0000000 --- a/tests/integration/continuous_writes.py +++ /dev/null @@ -1,394 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2026 Canonical Ltd. -# See LICENSE file for licensing details. - -import asyncio -import logging -import multiprocessing -import queue -import time -from contextlib import asynccontextmanager -from multiprocessing import log_to_stderr -from pathlib import Path -from types import SimpleNamespace -from typing import Optional - -import jubilant -from glide import ( - AdvancedGlideClientConfiguration, - BackoffStrategy, - GlideClient, - GlideClientConfiguration, - NodeAddress, - ServerCredentials, - TlsAdvancedConfiguration, -) -from tenacity import ( - retry, - stop_after_attempt, - wait_fixed, - wait_random, -) - -from literals import CLIENT_PORT, TLS_PORT, CharmUsers -from tests.integration.helpers import get_data_bag, get_password - -logger = logging.getLogger(__name__) - - -class WriteFailedError(Exception): - """Raised when a single write operation has failed.""" - - -def get_active_hostnames(juju: jubilant.Juju, app_name: str) -> str: - """Get hostnames of units in started state and not marked for scale down.""" - return ",".join( - [ - unit["private-ip"] - for unit in get_data_bag(juju, app_name, "valkey-peers").values() - if unit.get("start-state", "") == "started" - and unit.get("scale-down-state", None) is None - ] - ) - - -class ContinuousWrites: - """Utility class for managing continuous async writes to Valkey using GLIDE.""" - - KEY = "cw_key" - LAST_WRITTEN_VAL_PATH = "last_written_value" - VALKEY_PORT = 6379 - - def __init__( - self, - juju: jubilant.Juju, - app: str, - initial_count: int = 0, - in_between_sleep: float = 1.0, - tls_enabled: bool = False, - ): - self._juju = juju - self._app = app - self._is_stopped = True - self._event = None - self._queue = None - self._process = None - self._initial_count = initial_count - self._in_between_sleep = in_between_sleep - self._mp_ctx = multiprocessing.get_context("spawn") - self.tls_enabled = tls_enabled - - def _get_config(self) -> SimpleNamespace: - """Fetch current cluster configuration from Juju.""" - return SimpleNamespace( - endpoints=get_active_hostnames(self._juju, self._app), - valkey_password=get_password(self._juju, user=CharmUsers.VALKEY_ADMIN), - tls_enabled=self.tls_enabled, - ) - - async def _create_glide_client(self, config: Optional[SimpleNamespace] = None) -> GlideClient: - """Asynchronously create and return a configured GlideClient.""" - conf = config or self._get_config() - addresses = [ - NodeAddress(host, TLS_PORT if conf.tls_enabled else CLIENT_PORT) - for host in conf.endpoints.split(",") - ] - - credentials = ServerCredentials( - username=CharmUsers.VALKEY_ADMIN.value, password=conf.valkey_password - ) - - tls_cert = tls_key = tls_ca_cert = None - if conf.tls_enabled: - # Read locally stored certificate files - with open("client.pem", "rb") as f: - tls_cert = f.read() - with open("client.key", "rb") as f: - tls_key = f.read() - with open("client_ca.pem", "rb") as f: - tls_ca_cert = f.read() - logger.info( - "TLS is enabled. Loaded client certificate, key, and CA cert for Glide client configuration." - ) - - tls_config = TlsAdvancedConfiguration( - client_cert_pem=tls_cert if conf.tls_enabled else None, - client_key_pem=tls_key if conf.tls_enabled else None, - root_pem_cacerts=tls_ca_cert if conf.tls_enabled else None, - use_insecure_tls=True if conf.tls_enabled else None, - ) - - glide_config = GlideClientConfiguration( - addresses=addresses, - client_name="continuous_writes_client", - request_timeout=1000, - credentials=credentials, - reconnect_strategy=BackoffStrategy(num_of_retries=1, factor=50, exponent_base=2), - use_tls=True if conf.tls_enabled else False, - advanced_config=AdvancedGlideClientConfiguration(tls_config=tls_config), - ) - - return await GlideClient.create(glide_config) - - @retry(wait=wait_fixed(5) + wait_random(0, 5), stop=stop_after_attempt(5)) - def start(self) -> None: - """Run continuous writes in the background.""" - if not self._is_stopped: - self.clear() - - self._is_stopped = False - # Create primitives using the spawn context - self._event = self._mp_ctx.Event() - self._queue = self._mp_ctx.Queue() - - last_written_file = Path(self.LAST_WRITTEN_VAL_PATH) - if not last_written_file.exists(): - last_written_file.write_text(str(self._initial_count)) - - self._process = self._mp_ctx.Process( - target=self._run_process, - name="continuous_writes", - args=(self._event, self._queue, self._initial_count, self._in_between_sleep), - ) - - self.update() - self._process.start() - - def update(self) -> None: - """Update cluster related conf (scaling, password changes).""" - if self._queue: - self._queue.put(self._get_config()) - - @retry(wait=wait_fixed(5) + wait_random(0, 5), stop=stop_after_attempt(5)) - def clear(self) -> SimpleNamespace | None: - """Stop writes and delete the tracking key/file.""" - result = None - if not self._is_stopped: - result = self.stop() - - try: - asyncio.run(self._async_delete()) - except Exception as e: - logger.warning("Failed to clear continuous writes data from Valkey: %s", e) - - last_written_file = Path(self.LAST_WRITTEN_VAL_PATH) - if last_written_file.exists(): - last_written_file.unlink() - return result - - @retry(wait=wait_fixed(5) + wait_random(0, 5), stop=stop_after_attempt(5)) - async def async_clear(self) -> SimpleNamespace | None: - """Stop writes and delete the tracking key/file.""" - result = None - if not self._is_stopped: - result = await self.async_stop() - - try: - await self._async_delete() - except Exception as e: - logger.warning("Failed to clear continuous writes data from Valkey: %s", e) - - last_written_file = Path(self.LAST_WRITTEN_VAL_PATH) - if last_written_file.exists(): - last_written_file.unlink() - return result - - async def _async_delete(self) -> None: - client = await self._create_glide_client() - try: - await client.delete([self.KEY]) - finally: - await client.close() - - def count(self) -> int: - """Return number of items in the list.""" - return asyncio.run(self._async_count()) - - async def _async_count(self) -> int: - client = await self._create_glide_client() - try: - return await client.llen(self.KEY) - finally: - await client.close() - - def max_stored_id(self) -> int: - """Return the most recently inserted ID (top of list).""" - return asyncio.run(self._async_max_stored_id()) - - async def _async_max_stored_id(self) -> int: - client = await self._create_glide_client() - try: - val = await client.lindex(self.KEY, 0) - return int(val.decode()) if val else 0 - finally: - await client.close() - - @retry(wait=wait_fixed(5) + wait_random(0, 5), stop=stop_after_attempt(5)) - def stop(self) -> SimpleNamespace: - """Stop the background process and return summary statistics.""" - if not self._is_stopped and self._process: - self._event.set() - self._process.join(timeout=30) - self._process.terminate() - self._is_stopped = True - - result = SimpleNamespace() - result.max_stored_id = self.max_stored_id() - result.count = self.count() - result.last_expected_id = int(Path(self.LAST_WRITTEN_VAL_PATH).read_text().strip()) - - return result - - @retry(wait=wait_fixed(5) + wait_random(0, 5), stop=stop_after_attempt(5)) - async def async_stop(self) -> SimpleNamespace: - """Stop the background process and return summary statistics.""" - if not self._is_stopped and self._process: - self._event.set() - self._process.join(timeout=30) - self._process.terminate() - self._is_stopped = True - - result = SimpleNamespace() - result.max_stored_id = await self._async_max_stored_id() - result.count = await self._async_count() - result.last_expected_id = int(Path(self.LAST_WRITTEN_VAL_PATH).read_text().strip()) - - return result - - @staticmethod - def _run_process(event, data_queue, starting_number: int, in_between_sleep: float): - """Start synchronously the asyncio event loop.""" - proc_logger = log_to_stderr() - proc_logger.setLevel(logging.INFO) - - # FIX 2: Do the blocking read synchronously BEFORE starting the async loop - initial_config = data_queue.get(block=True) - - asyncio.run( - ContinuousWrites._async_run( - event, data_queue, starting_number, initial_config, in_between_sleep, proc_logger - ) - ) - - @staticmethod - async def _async_run( - event, - data_queue, - starting_number: int, - initial_config: SimpleNamespace, - in_between_sleep: float, - proc_logger: logging.Logger, - ): - """Async loop for writing data continuously.""" - - async def _make_client(conf: SimpleNamespace) -> GlideClient: - addresses = [ - NodeAddress(host, TLS_PORT if conf.tls_enabled else CLIENT_PORT) - for host in conf.endpoints.split(",") - ] - - credentials = ServerCredentials( - username=CharmUsers.VALKEY_ADMIN.value, password=conf.valkey_password - ) - - tls_cert = tls_key = tls_ca_cert = None - if conf.tls_enabled: - # Read locally stored certificate files - with open("client.pem", "rb") as f: - tls_cert = f.read() - with open("client.key", "rb") as f: - tls_key = f.read() - with open("client_ca.pem", "rb") as f: - tls_ca_cert = f.read() - - tls_config = TlsAdvancedConfiguration( - client_cert_pem=tls_cert if conf.tls_enabled else None, - client_key_pem=tls_key if conf.tls_enabled else None, - root_pem_cacerts=tls_ca_cert if conf.tls_enabled else None, - use_insecure_tls=True if conf.tls_enabled else None, - ) - - glide_config = GlideClientConfiguration( - addresses=addresses, - client_name="continuous_writes_worker", - request_timeout=1000, - credentials=credentials, - reconnect_strategy=BackoffStrategy(num_of_retries=1, factor=50, exponent_base=2), - use_tls=True if conf.tls_enabled else False, - advanced_config=AdvancedGlideClientConfiguration(tls_config=tls_config), - ) - - return await GlideClient.create(glide_config) - - @asynccontextmanager - async def with_client(conf: SimpleNamespace): - client = await _make_client(conf) - try: - yield client - finally: - await client.close() - - current_val = starting_number - last_written_value = starting_number - config = initial_config - - proc_logger.info("Starting continuous async writes from %s", current_val) - - try: - while not event.is_set(): - try: - config = data_queue.get_nowait() - proc_logger.info("Configuration updated, client reconnected.") - except queue.Empty: - pass - - try: - proc_logger.info("Writing value: %s", current_val) - proc_logger.info("Current endpoints=%s", config.endpoints) - async with with_client(config) as client: - if not ( - res := await asyncio.wait_for( - client.lpush(ContinuousWrites.KEY, [str(current_val)]), timeout=5 - ) - ): - raise WriteFailedError("LPUSH returned 0/None") - proc_logger.info("Length after write: %s", res) - last_written_value = current_val - except Exception as e: - proc_logger.warning("Write failed at %s: %s", current_val, e) - finally: - await asyncio.sleep(in_between_sleep) - if event.is_set(): - break - - current_val += 1 - - finally: - Path(ContinuousWrites.LAST_WRITTEN_VAL_PATH).write_text(str(last_written_value)) - proc_logger.info("Continuous writes process exiting.") - - -if __name__ == "__main__": - import jubilant - - juju_env = jubilant.Juju(model="testing") - cw = ContinuousWrites(juju=juju_env, app="valkey", in_between_sleep=0.5) - cw.clear() - cw.start() - # stop on ctrl + C or after some time - hostnames = get_active_hostnames(juju_env, "valkey") - try: - while True: - time.sleep(1) - if new_hostnames := get_active_hostnames(juju_env, "valkey") != hostnames: - logger.info( - "Hostnames changed from %s to %s, updating continuous writes client.", - hostnames, - new_hostnames, - ) - hostnames = new_hostnames - cw.update() - except KeyboardInterrupt: - pass - stats = cw.clear() - print(f"Stopped. Stats: {stats}") diff --git a/tests/integration/cw_helpers.py b/tests/integration/cw_helpers.py index 3fbb998..d6189c6 100644 --- a/tests/integration/cw_helpers.py +++ b/tests/integration/cw_helpers.py @@ -2,86 +2,222 @@ # Copyright 2025 Canonical Ltd. # See LICENSE file for licensing details. -import asyncio +import base64 import json import logging -import subprocess from pathlib import Path - -from tests.integration.continuous_writes import ContinuousWrites -from tests.integration.helpers import create_valkey_client, exec_valkey_cli +from typing import NamedTuple + +import jubilant + +from literals import CLIENT_PORT, TLS_PORT, Substrate +from tests.integration.helpers import ( + APP_NAME, + GLIDE_RUNNER_NAME, + TLS_CA_FILE, + TLS_CERT_FILE, + TLS_KEY_FILE, + CharmUsers, + download_client_certificate_from_unit, + exec_valkey_cli, + get_cluster_addresses, + get_password, +) logger = logging.getLogger(__name__) -# WRITES_LAST_WRITTEN_VAL_PATH = "last_written_value" -# KEY = "cw_key" -KEY = ContinuousWrites.KEY -WRITES_LAST_WRITTEN_VAL_PATH = ContinuousWrites.LAST_WRITTEN_VAL_PATH +class ContinuousWritesStats(NamedTuple): + last_written_value: int + total_count: int -def start_continuous_writes( - endpoints: str, - valkey_user: str, - valkey_password: str, - sentinel_user: str, - sentinel_password: str, +KEY = "cw_key" + + +def configure_cw_runner( + juju: jubilant.Juju, + app: str = GLIDE_RUNNER_NAME, + valkey_app: str = APP_NAME, + tls_enabled: bool = False, + substrate: Substrate = Substrate.VM, ) -> None: - """Create a subprocess instance of `continuous writes` and start writing data to valkey.""" - subprocess.Popen( - [ - "python3", - "tests/integration/continuous_writes.py", - endpoints, - valkey_user, - valkey_password, - sentinel_user, - sentinel_password, + """Configure the continuous writes runner charm to connect to Valkey via config options. + + Endpoints and the admin password are fetched automatically from the Juju + model. When ``tls_enabled`` is True, client certificates are downloaded + from a Valkey unit and passed as base64-encoded strings. + + Args: + juju: Juju client instance. + app: Name of the continuous writes runner charm application to configure. + valkey_app: Name of the Valkey application to fetch endpoints from. + tls_enabled: Whether TLS is enabled. + substrate: The substrate type (VM or Kubernetes). + """ + if substrate == Substrate.VM: + addresses = get_cluster_addresses(juju, valkey_app) + else: + # for k8s we construct the hostname + addresses = [ + unit_name.replace("/", "-") + "." + valkey_app + "-endpoints" + for unit_name in juju.status().get_units(valkey_app) ] + + port = TLS_PORT if tls_enabled else CLIENT_PORT + endpoints = ",".join(f"{h}:{port}" for h in addresses) + password = get_password(juju, user=CharmUsers.VALKEY_ADMIN) + + cacert = cert = key = "" + if tls_enabled: + download_client_certificate_from_unit(juju, app_name=valkey_app) + cacert = base64.b64encode(Path(TLS_CA_FILE).read_bytes()).decode() + cert = base64.b64encode(Path(TLS_CERT_FILE).read_bytes()).decode() + key = base64.b64encode(Path(TLS_KEY_FILE).read_bytes()).decode() + + glide_config = json.dumps( + { + "endpoints": endpoints, + "username": CharmUsers.VALKEY_ADMIN.value, + "password": password, + "tls_enabled": tls_enabled, + "cacert": cacert, + "cert": cert, + "key": key, + } ) + juju.config(app=app, values={"glide-config": glide_config}) -def stop_continuous_writes() -> None: - """Shut down the subprocess instance of the `continuous writes`.""" - proc = subprocess.Popen(["pkill", "-15", "-f", "continuous_writes.py"]) - proc.communicate() +def start_continuous_writes( + juju: jubilant.Juju, + unit: str = f"{GLIDE_RUNNER_NAME}/leader", + sleep_interval: float = 1.0, + config: dict | None = None, + clear: bool = True, +) -> int: + """Trigger the start-continuous-writes action on the requirer charm unit. + + Connection info is taken from the Valkey relation by default. To use + config options instead, pass a ``config`` dict; the options are applied + to the application before the action runs. + + Args: + juju: Juju client instance. + unit: Unit name (e.g. ``"requirer-charm/0"``). + sleep_interval: Seconds to sleep between writes. + config: Optional charm config values to set before starting. + clear: Delete any existing list values before starting. + + Returns: + PID of the spawned continuous-writes daemon. + """ + if config: + app = unit.split("/")[0] + juju.config(app=app, values=config) + + result = juju.run( + unit, + "start-continuous-writes", + params={"sleep-interval": sleep_interval, "clear-existing": clear}, + ) + assert result.results.get("ok"), f"start-continuous-writes failed: {result}" + pid = int(result.results["pid"]) + logger.info("Continuous-writes daemon started on %s with PID %d", unit, pid) + return pid + + +def stop_continuous_writes( + juju: jubilant.Juju, unit: str = f"{GLIDE_RUNNER_NAME}/leader" +) -> ContinuousWritesStats: + """Trigger the stop-continuous-writes action and return write statistics. + + Args: + juju: Juju client instance. + unit: Unit name to run the action on. + + Returns: + ``ContinuousWritesStats`` with ``last_written_value`` (last integer + successfully written to Valkey) and ``total_count`` (number of items + in the list). + """ + result = juju.run(unit, "stop-continuous-writes") + assert result.results.get("ok"), f"stop-continuous-writes failed: {result}" + stats = ContinuousWritesStats( + last_written_value=int(result.results["last-written-value"]), + total_count=int(result.results["count"]), + ) + logger.info( + "Continuous-writes stopped on %s — last_written=%d, count=%d", + unit, + stats.last_written_value, + stats.total_count, + ) + return stats -async def assert_continuous_writes_increasing( - hostnames: list[str], - username: str, - password: str, - tls_enabled: bool = False, +def assert_continuous_writes_increasing( + juju: jubilant.Juju, + unit: str = f"{GLIDE_RUNNER_NAME}/leader", + wait: float = 10.0, ) -> None: - """Assert that the continuous writes are increasing.""" - async with create_valkey_client( - hostnames, - username=username, - password=password, - tls_enabled=tls_enabled, - ) as client: - writes_count = await client.llen(KEY) - await asyncio.sleep(10) - more_writes = await client.llen(KEY) - assert more_writes > writes_count, "Writes not continuing to DB" - logger.info("Continuous writes are increasing.") + """Run the assert-continuous-writes-increasing action on the requirer charm unit. + + Args: + juju: Juju client instance. + unit: Unit name to run the action on. + wait: Seconds to wait between state samples inside the charm. + """ + result = juju.run(unit, "assert-continuous-writes-increasing", {"wait": wait}) + assert result.status == "completed" and result.results.get("ok"), ( + f"assert-continuous-writes-increasing failed: {result}" + ) + logger.info( + "Continuous writes are increasing on %s (count %s -> %s)", + unit, + result.results.get("count-before"), + result.results.get("count-after"), + ) + + +def clear_continuous_writes(juju: jubilant.Juju, unit: str) -> None: + """Trigger the clear-continuous-writes action on the requirer charm unit. + + Deletes the continuous-writes key from Valkey. Can be called while the + daemon is stopped to reset data between test runs. + + Args: + juju: Juju client instance. + unit: Unit name to run the action on. + """ + result = juju.run(unit, "clear-continuous-writes") + assert result.results.get("ok"), f"clear-continuous-writes failed: {result}" + logger.info("Continuous-writes data cleared on %s", unit) def assert_continuous_writes_consistent( - hostnames: list[str], + endpoints: list[str], username: str, password: str, + last_written_value: int, tls_enabled: bool = False, ) -> None: - """Assert that the continuous writes are consistent.""" - last_written_value = int(Path(WRITES_LAST_WRITTEN_VAL_PATH).read_text()) - - if not last_written_value: - raise ValueError("Could not read last written value from file.") - - values: list[int] | None = None - - for endpoint in hostnames: + """Assert consistency of continuous-writes data across all Valkey instances. + + Checks two properties: + - The head of the list on every replica matches ``last_written_value``. + - Every replica holds an identical copy of the list. + + Args: + endpoints: List of Valkey endpoints to check. + username: Valkey username. + password: Valkey password. + last_written_value: Last integer successfully written, from ``stop_continuous_writes``. + tls_enabled: Whether to use TLS when connecting to Valkey. + """ + reference: list[int] | None = None + + for endpoint in endpoints: current_values: list[int] = json.loads( exec_valkey_cli( endpoint, @@ -92,13 +228,23 @@ def assert_continuous_writes_consistent( tls_enabled=tls_enabled, ).stdout ) - if values is None: - values = current_values last_value = int(current_values[0]) if current_values else None - assert last_written_value == last_value, ( - f"endpoint: {endpoint}, expected value: {last_written_value}, current value: {last_value}" + assert last_value == last_written_value, ( + f"endpoint {endpoint}: head of list is {last_value}, " + f"expected last_written_value={last_written_value}" ) - assert values == current_values, ( - f"endpoint: {endpoint}, expected values: {values}, current values: {current_values}" + + if reference is None: + reference = current_values + assert current_values == reference, ( + f"endpoint {endpoint}: list diverges from reference.\n" + f" reference (first endpoint): {reference[:10]}...\n" + f" this endpoint: {current_values[:10]}..." ) + + logger.info( + "Consistency check passed across %d endpoints (list len=%d).", + len(endpoints), + len(reference or []), + ) diff --git a/tests/integration/glide_helpers.py b/tests/integration/glide_helpers.py new file mode 100644 index 0000000..7de1391 --- /dev/null +++ b/tests/integration/glide_helpers.py @@ -0,0 +1,114 @@ +# Copyright 2026 Canonical Ltd. +# See LICENSE file for licensing details. + +"""Serialization/deserialization helpers for GlideClientConfiguration objects. + +Converts a GlideClientConfiguration (and its nested objects) to/from a JSON +string so it can be passed as a Juju action parameter. + +Bytes fields are base64-encoded; enums are stored by name; nested Glide +objects are tagged with ``__class__`` for round-trip reconstruction. +""" + +import base64 +import json +from enum import Enum +from typing import Any + +from glide import ( + AdvancedGlideClientConfiguration, + BackoffStrategy, + GlideClientConfiguration, + NodeAddress, + ReadFrom, + ServerCredentials, + TlsAdvancedConfiguration, +) + +# Maps each Glide class to the set of fields that should be serialized. +# Tuple values like ``(list, NodeAddress)`` are documentation only — the +# serialize/deserialize logic recurses structurally, not via this type info. +SCHEMA: dict[type, dict[str, Any]] = { + GlideClientConfiguration: { + "addresses": (list, NodeAddress), + "use_tls": bool, + "request_timeout": (int, type(None)), + "read_from": ReadFrom, + "credentials": (ServerCredentials, type(None)), + "reconnect_strategy": (BackoffStrategy, type(None)), + "advanced_config": (AdvancedGlideClientConfiguration, type(None)), + }, + NodeAddress: { + "host": str, + "port": int, + }, + ServerCredentials: { + "username": (str, type(None)), + "password": (str, type(None)), + }, + BackoffStrategy: { + "num_of_retries": int, + "factor": int, + "exponent_base": int, + "jitter_percent": (int, type(None)), + }, + AdvancedGlideClientConfiguration: { + "connection_timeout": (int, type(None)), + "tls_config": (TlsAdvancedConfiguration, type(None)), + }, + TlsAdvancedConfiguration: { + "use_insecure_tls": bool, + "client_cert_pem": (bytes, type(None)), + "client_key_pem": (bytes, type(None)), + "root_pem_cacerts": (bytes, type(None)), + }, +} + +_GLIDE_CLASSES: dict[str, type] = {cls.__name__: cls for cls in SCHEMA} +_ENUM_CLASSES: dict[str, type[Enum]] = {"ReadFrom": ReadFrom} + + +def serialize(obj: Any) -> Any: + """Recursively serialize a Glide object to a JSON-compatible structure.""" + if obj is None: + return None + if isinstance(obj, bytes): + return {"__bytes__": base64.b64encode(obj).decode()} + if isinstance(obj, Enum): + return {"__enum__": type(obj).__name__, "value": obj.name} + if type(obj) in SCHEMA: + return { + "__class__": type(obj).__name__, + **{field: serialize(getattr(obj, field)) for field in SCHEMA[type(obj)]}, + } + if isinstance(obj, list): + return [serialize(i) for i in obj] + return obj # str, int, bool, None + + +def deserialize(d: Any) -> Any: + """Recursively deserialize a JSON-compatible structure back to Glide objects.""" + if d is None or not isinstance(d, (dict, list)): + return d + if isinstance(d, list): + return [deserialize(i) for i in d] + if "__bytes__" in d: + return base64.b64decode(d["__bytes__"]) + if "__enum__" in d: + cls = _ENUM_CLASSES[d["__enum__"]] + return cls[d["value"]] + if "__class__" in d: + cls = _GLIDE_CLASSES[d["__class__"]] + fields = {k: deserialize(v) for k, v in d.items() if k != "__class__"} + return cls(**fields) + return d + + +def serialize_glide_config(config: GlideClientConfiguration) -> str: + """Serialize a GlideClientConfiguration to a JSON string.""" + return json.dumps(serialize(config)) + + +def deserialize_glide_config(payload: str) -> GlideClientConfiguration: + """Deserialize a JSON string back to a GlideClientConfiguration.""" + return deserialize(json.loads(payload)) diff --git a/tests/integration/ha/test_failover.py b/tests/integration/ha/test_failover.py index 02d8430..c2e95e6 100644 --- a/tests/integration/ha/test_failover.py +++ b/tests/integration/ha/test_failover.py @@ -2,19 +2,21 @@ # Copyright 2025 Canonical Ltd. # See LICENSE file for licensing details. -import asyncio import json import logging +from time import sleep import jubilant import pytest from tenacity import Retrying, stop_after_attempt, wait_fixed from literals import CharmUsers, Substrate -from tests.integration.continuous_writes import ContinuousWrites from tests.integration.cw_helpers import ( assert_continuous_writes_consistent, assert_continuous_writes_increasing, + configure_cw_runner, + start_continuous_writes, + stop_continuous_writes, ) from tests.integration.ha.helpers.helpers import ( K8S_RESTART_DELAY_DEFAULT, @@ -28,6 +30,7 @@ from ..helpers import ( APP_NAME, + GLIDE_RUNNER_NAME, IMAGE_RESOURCE, TLS_CHANNEL, TLS_NAME, @@ -54,7 +57,11 @@ @pytest.mark.parametrize("tls_enabled", [False, True], ids=["tls_off", "tls_on"]) def test_build_and_deploy( - tls_enabled: bool, charm: str, juju: jubilant.Juju, substrate: Substrate + tls_enabled: bool, + charm: str, + juju: jubilant.Juju, + substrate: Substrate, + glide_runner_charm: str, ) -> None: """Build the charm-under-test and deploy it with three units.""" if app := existing_app(juju): @@ -68,12 +75,16 @@ def test_build_and_deploy( trust=True, ) + juju.deploy(glide_runner_charm, GLIDE_RUNNER_NAME) + if tls_enabled: juju.deploy(TLS_NAME, channel=TLS_CHANNEL) juju.integrate(f"{APP_NAME}:client-certificates", TLS_NAME) juju.wait( - lambda status: are_apps_active_and_agents_idle(status, APP_NAME, idle_period=30), + lambda status: are_apps_active_and_agents_idle( + status, APP_NAME, GLIDE_RUNNER_NAME, idle_period=30 + ), timeout=600, ) @@ -85,20 +96,24 @@ def test_build_and_deploy( @pytest.mark.parametrize("tls_enabled", [False, True], ids=["tls_off", "tls_on"]) @pytest.mark.parametrize("signal", ["SIGKILL", "SIGTERM"], ids=["sigkill", "sigterm"]) @pytest.mark.parametrize("patched_delay", [False, True], ids=["default_delay", "patched_delay"]) -async def test_signal_db_process_on_primary( +def test_signal_db_process_on_primary( tls_enabled: bool, signal: str, patched_delay: bool, juju: jubilant.Juju, substrate: Substrate, - c_writes: ContinuousWrites, - c_writes_async_clean, ) -> None: """Make sure the cluster can self-heal when the leader goes down.""" app_name = existing_app(juju) or APP_NAME if tls_enabled: download_client_certificate_from_unit(juju, APP_NAME) - c_writes.tls_enabled = tls_enabled + + configure_cw_runner( + juju, + valkey_app=app_name, + tls_enabled=tls_enabled, + substrate=substrate, + ) # make sure we have at least two units so we can stop one of them init_units_count = len(juju.status().get_units(app_name)) @@ -112,8 +127,8 @@ async def test_signal_db_process_on_primary( ) init_units_count = len(juju.status().get_units(app_name)) - c_writes.start() - await asyncio.sleep(10) + start_continuous_writes(juju, clear=True) + sleep(10) primary_ip = get_primary_ip(juju, app_name, tls_enabled=tls_enabled) assert primary_ip, "Failed to get primary endpoint from valkey." @@ -156,7 +171,7 @@ async def test_signal_db_process_on_primary( restart_delay += 10 # add some buffer to the restart delay logger.info("Waiting for primary unit to restart. Restart delay is %s seconds.", restart_delay) - await asyncio.sleep(restart_delay) + sleep(restart_delay) logger.info("Pinging primary unit to ensure it's up.") for attempt in Retrying(stop=stop_after_attempt(10), wait=wait_fixed(5), reraise=True): @@ -194,42 +209,44 @@ async def test_signal_db_process_on_primary( # if failover happened the old primary will need some time to restart and sync with the new primary before it shows up as a connected replica for attempt in Retrying(stop=stop_after_attempt(10), wait=wait_fixed(10), reraise=True): with attempt: - number_of_replicas = await get_number_connected_replicas( - addresses, CharmUsers.VALKEY_ADMIN, admin_password, tls_enabled=tls_enabled - ) + number_of_replicas = get_number_connected_replicas(juju, tls_enabled=tls_enabled) assert number_of_replicas == init_units_count - 1, ( f"Expected {init_units_count - 1} replicas to be connected after primary restart, got {number_of_replicas}" ) # ensure data is written in the cluster logger.info("Checking continuous writes are increasing after primary restart.") - await assert_continuous_writes_increasing( - hostnames=addresses, - username=CharmUsers.VALKEY_ADMIN, - password=admin_password, - tls_enabled=tls_enabled, - ) + assert_continuous_writes_increasing(juju) - await c_writes.async_stop() + stats = stop_continuous_writes(juju) assert_continuous_writes_consistent( - hostnames=addresses, - username=CharmUsers.VALKEY_ADMIN, - password=admin_password, + endpoints=get_cluster_addresses(juju, app_name), + username=CharmUsers.VALKEY_ADMIN.value, + password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), + last_written_value=stats.last_written_value, tls_enabled=tls_enabled, ) @pytest.mark.parametrize("tls_enabled", [False, True], ids=["tls_off", "tls_on"]) -async def test_freeze_db_process_on_primary( - tls_enabled: bool, juju: jubilant.Juju, substrate: Substrate, c_writes, c_writes_async_clean +def test_freeze_db_process_on_primary( + tls_enabled: bool, + juju: jubilant.Juju, + substrate: Substrate, ) -> None: """Make sure the cluster can self-heal when the leader goes down.""" app_name = existing_app(juju) or APP_NAME addresses = get_cluster_addresses(juju, app_name) if tls_enabled: download_client_certificate_from_unit(juju, APP_NAME) - c_writes.tls_enabled = tls_enabled + + configure_cw_runner( + juju, + valkey_app=app_name, + tls_enabled=tls_enabled, + substrate=substrate, + ) # make sure we have at least two units so we can stop one of them init_units_count = len(juju.status().get_units(app_name)) @@ -243,8 +260,8 @@ async def test_freeze_db_process_on_primary( ) init_units_count = len(juju.status().get_units(app_name)) - c_writes.start() - await asyncio.sleep(10) + start_continuous_writes(juju, clear=True) + sleep(10) primary_ip = get_primary_ip(juju, app_name, tls_enabled=tls_enabled) assert primary_ip, "Failed to get primary endpoint from valkey." @@ -269,7 +286,7 @@ async def test_freeze_db_process_on_primary( # ensure the stopped unit was restarted logger.info("Waiting for failover to happen.") - await asyncio.sleep(FAILOVER_DELAY) + sleep(FAILOVER_DELAY) new_primary_ip = get_primary_ip(juju, app_name, tls_enabled=tls_enabled) assert new_primary_ip != primary_ip, "Primary IP did not change after failover delay." @@ -279,19 +296,12 @@ async def test_freeze_db_process_on_primary( new_primary_hostname = f"{new_primary_unit_name.replace('/', '-')}.{app_name}-endpoints" new_primary_endpoint = new_primary_ip if substrate == Substrate.VM else new_primary_hostname - number_of_replicas = await get_number_connected_replicas( - addresses, CharmUsers.VALKEY_ADMIN, admin_password, tls_enabled=tls_enabled - ) + number_of_replicas = get_number_connected_replicas(juju, tls_enabled=tls_enabled) assert number_of_replicas == init_units_count - 2, ( f"Expected {init_units_count - 2} replicas to be connected, got {number_of_replicas}" ) - await assert_continuous_writes_increasing( - hostnames=addresses, - username=CharmUsers.VALKEY_ADMIN, - password=admin_password, - tls_enabled=tls_enabled, - ) + assert_continuous_writes_increasing(juju) send_process_control_signal( unit_name=primary_unit_name, @@ -325,9 +335,7 @@ async def test_freeze_db_process_on_primary( logger.info("Old primary unit is available again.") logger.info("Checking number of connected replicas after primary restart.") - number_of_replicas = await get_number_connected_replicas( - addresses, CharmUsers.VALKEY_ADMIN, admin_password, tls_enabled=tls_enabled - ) + number_of_replicas = get_number_connected_replicas(juju, tls_enabled=tls_enabled) assert number_of_replicas == init_units_count - 1, ( f"Expected {init_units_count - 1} replicas to be connected after primary restart, got {number_of_replicas}" ) @@ -349,32 +357,34 @@ async def test_freeze_db_process_on_primary( # ensure data is written in the cluster logger.info("Checking continuous writes are increasing after primary restart.") - await assert_continuous_writes_increasing( - hostnames=addresses, - username=CharmUsers.VALKEY_ADMIN, - password=admin_password, - tls_enabled=tls_enabled, - ) + assert_continuous_writes_increasing(juju) - await c_writes.async_stop() + stats = stop_continuous_writes(juju) assert_continuous_writes_consistent( - hostnames=addresses, - username=CharmUsers.VALKEY_ADMIN, - password=admin_password, + endpoints=get_cluster_addresses(juju, app_name), + username=CharmUsers.VALKEY_ADMIN.value, + password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), + last_written_value=stats.last_written_value, tls_enabled=tls_enabled, ) @pytest.mark.parametrize("tls_enabled", [False, True], ids=["tls_off", "tls_on"]) -async def test_full_cluster_restart( - tls_enabled: bool, juju: jubilant.Juju, c_writes, c_writes_async_clean, substrate: Substrate +def test_full_cluster_restart( + tls_enabled: bool, juju: jubilant.Juju, substrate: Substrate ) -> None: """Make sure the cluster can self-heal after all members went down.""" app_name = existing_app(juju) or APP_NAME if tls_enabled: download_client_certificate_from_unit(juju, APP_NAME) - c_writes.tls_enabled = tls_enabled + + configure_cw_runner( + juju, + valkey_app=app_name, + tls_enabled=tls_enabled, + substrate=substrate, + ) # make sure we have at least two units so we can stop one of them init_units_count = len(juju.status().get_units(app_name)) @@ -388,8 +398,8 @@ async def test_full_cluster_restart( ) init_units_count = len(juju.status().get_units(app_name)) - c_writes.start() - await asyncio.sleep(10) + start_continuous_writes(juju, clear=True) + sleep(10) # update the restart delay for all units for unit in juju.status().get_units(app_name): @@ -420,7 +430,7 @@ async def test_full_cluster_restart( # ensure the stopped unit was restarted logger.info("Waiting for units to restart.") - await asyncio.sleep(RESTART_DELAY_PATCHED + 10) + sleep(RESTART_DELAY_PATCHED + 10) for unit, unit_info in juju.status().get_units(app_name).items(): unit_ip = unit_info.public_address if substrate == Substrate.VM else unit_info.address @@ -432,29 +442,23 @@ async def test_full_cluster_restart( logger.info("All units are available again.") logger.info("Checking number of connected replicas after primary restart.") - addresses = get_cluster_addresses(juju, app_name) - number_of_replicas = await get_number_connected_replicas( - addresses, CharmUsers.VALKEY_ADMIN, admin_password, tls_enabled=tls_enabled - ) + + number_of_replicas = get_number_connected_replicas(juju, tls_enabled=tls_enabled) assert number_of_replicas == init_units_count - 1, ( f"Expected {init_units_count - 1} replicas to be connected after primary restart, got {number_of_replicas}" ) # ensure data is written in the cluster logger.info("Checking continuous writes are increasing after primary restart.") - await assert_continuous_writes_increasing( - hostnames=addresses, - username=CharmUsers.VALKEY_ADMIN, - password=admin_password, - tls_enabled=tls_enabled, - ) + assert_continuous_writes_increasing(juju) - await c_writes.async_stop() + stats = stop_continuous_writes(juju) assert_continuous_writes_consistent( - hostnames=addresses, - username=CharmUsers.VALKEY_ADMIN, - password=admin_password, + endpoints=get_cluster_addresses(juju, app_name), + username=CharmUsers.VALKEY_ADMIN.value, + password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), + last_written_value=stats.last_written_value, tls_enabled=tls_enabled, ) @@ -469,14 +473,18 @@ async def test_full_cluster_restart( @pytest.mark.parametrize("tls_enabled", [False, True], ids=["tls_off", "tls_on"]) -async def test_full_cluster_crash( - tls_enabled: bool, juju: jubilant.Juju, c_writes, c_writes_async_clean, substrate: Substrate -) -> None: +def test_full_cluster_crash(tls_enabled: bool, juju: jubilant.Juju, substrate: Substrate) -> None: """Make sure the cluster can self-heal after all members went down.""" app_name = existing_app(juju) or APP_NAME if tls_enabled: download_client_certificate_from_unit(juju, APP_NAME) - c_writes.tls_enabled = tls_enabled + + configure_cw_runner( + juju, + valkey_app=app_name, + tls_enabled=tls_enabled, + substrate=substrate, + ) # make sure we have at least two units so we can stop one of them init_units_count = len(juju.status().get_units(app_name)) @@ -490,8 +498,8 @@ async def test_full_cluster_crash( ) init_units_count = len(juju.status().get_units(app_name)) - c_writes.start() - await asyncio.sleep(10) + start_continuous_writes(juju, clear=True) + sleep(10) # update the restart delay for all units for unit in juju.status().get_units(app_name): @@ -522,7 +530,7 @@ async def test_full_cluster_crash( # ensure the stopped unit was restarted logger.info("Waiting for units to restart.") - await asyncio.sleep(RESTART_DELAY_PATCHED + 10) + sleep(RESTART_DELAY_PATCHED + 10) for unit, unit_info in juju.status().get_units(app_name).items(): unit_ip = unit_info.public_address if substrate == Substrate.VM else unit_info.address @@ -534,29 +542,23 @@ async def test_full_cluster_crash( logger.info("All units are available again.") logger.info("Checking number of connected replicas after primary restart.") - addresses = get_cluster_addresses(juju, app_name) - number_of_replicas = await get_number_connected_replicas( - addresses, CharmUsers.VALKEY_ADMIN, admin_password, tls_enabled=tls_enabled - ) + + number_of_replicas = get_number_connected_replicas(juju, tls_enabled=tls_enabled) assert number_of_replicas == init_units_count - 1, ( f"Expected {init_units_count - 1} replicas to be connected after primary restart, got {number_of_replicas}" ) # ensure data is written in the cluster logger.info("Checking continuous writes are increasing after primary restart.") - await assert_continuous_writes_increasing( - hostnames=addresses, - username=CharmUsers.VALKEY_ADMIN, - password=admin_password, - tls_enabled=tls_enabled, - ) + assert_continuous_writes_increasing(juju) - await c_writes.async_stop() + stats = stop_continuous_writes(juju) assert_continuous_writes_consistent( - hostnames=addresses, - username=CharmUsers.VALKEY_ADMIN, - password=admin_password, + endpoints=get_cluster_addresses(juju, app_name), + username=CharmUsers.VALKEY_ADMIN.value, + password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), + last_written_value=stats.last_written_value, tls_enabled=tls_enabled, ) @@ -571,14 +573,18 @@ async def test_full_cluster_crash( @pytest.mark.parametrize("tls_enabled", [False, True], ids=["tls_off", "tls_on"]) -async def test_reboot_primary( - tls_enabled: bool, juju: jubilant.Juju, c_writes, c_writes_async_clean, substrate: Substrate -) -> None: +def test_reboot_primary(tls_enabled: bool, juju: jubilant.Juju, substrate: Substrate) -> None: """Make sure the cluster can self-heal when the leader goes down.""" app_name = existing_app(juju) or APP_NAME if tls_enabled: download_client_certificate_from_unit(juju, APP_NAME) - c_writes.tls_enabled = tls_enabled + + configure_cw_runner( + juju, + valkey_app=app_name, + tls_enabled=tls_enabled, + substrate=substrate, + ) # make sure we have at least two units so we can stop one of them init_units_count = len(juju.status().get_units(app_name)) @@ -592,9 +598,8 @@ async def test_reboot_primary( ) init_units_count = len(juju.status().get_units(app_name)) - await c_writes.async_clear() - c_writes.start() - await asyncio.sleep(10) + start_continuous_writes(juju, clear=True) + sleep(10) primary_ip = get_primary_ip(juju, app_name, tls_enabled=tls_enabled) assert primary_ip, "Failed to get primary endpoint from valkey." @@ -606,7 +611,7 @@ async def test_reboot_primary( reboot_unit(juju, primary_unit_name, substrate) # wait for unit to reboot - await asyncio.sleep(3) + sleep(3) # make sure the process is stopped admin_password = get_password(juju, CharmUsers.VALKEY_ADMIN) @@ -623,7 +628,12 @@ async def test_reboot_primary( timeout=1200, ) - c_writes.update() + configure_cw_runner( + juju, + valkey_app=app_name, + tls_enabled=tls_enabled, + substrate=substrate, + ) # on k8s we get a new ip new_ip = get_ip_from_unit(juju, primary_unit_name, substrate) @@ -631,42 +641,37 @@ async def test_reboot_primary( "Primary unit is not responding after reboot." ) - number_of_replicas = await get_number_connected_replicas( - get_cluster_addresses(juju, app_name), - CharmUsers.VALKEY_ADMIN, - admin_password, - tls_enabled=tls_enabled, - ) + number_of_replicas = get_number_connected_replicas(juju, tls_enabled=tls_enabled) assert number_of_replicas == init_units_count - 1, ( f"Expected {init_units_count - 1} replicas to be connected, got {number_of_replicas}" ) - await assert_continuous_writes_increasing( - hostnames=get_cluster_addresses(juju, app_name), - username=CharmUsers.VALKEY_ADMIN, - password=admin_password, - tls_enabled=tls_enabled, - ) + assert_continuous_writes_increasing(juju) - await c_writes.async_stop() + stats = stop_continuous_writes(juju) assert_continuous_writes_consistent( - hostnames=get_cluster_addresses(juju, app_name), - username=CharmUsers.VALKEY_ADMIN, - password=admin_password, + endpoints=get_cluster_addresses(juju, app_name), + username=CharmUsers.VALKEY_ADMIN.value, + password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), + last_written_value=stats.last_written_value, tls_enabled=tls_enabled, ) @pytest.mark.parametrize("tls_enabled", [False, True], ids=["tls_off", "tls_on"]) -async def test_full_cluster_reboot( - tls_enabled: bool, juju: jubilant.Juju, c_writes, c_writes_async_clean, substrate: Substrate -) -> None: +def test_full_cluster_reboot(tls_enabled: bool, juju: jubilant.Juju, substrate: Substrate) -> None: """Make sure the cluster can self-heal after all members went down.""" app_name = existing_app(juju) or APP_NAME if tls_enabled: download_client_certificate_from_unit(juju, APP_NAME) - c_writes.tls_enabled = tls_enabled + + configure_cw_runner( + juju, + valkey_app=app_name, + tls_enabled=tls_enabled, + substrate=substrate, + ) # make sure we have at least two units so we can stop one of them init_units_count = len(juju.status().get_units(app_name)) @@ -680,13 +685,13 @@ async def test_full_cluster_reboot( ) init_units_count = len(juju.status().get_units(app_name)) - c_writes.start() - await asyncio.sleep(10) + start_continuous_writes(juju, clear=True) + sleep(10) for unit in juju.status().get_units(app_name): reboot_unit(juju, unit, substrate) - await asyncio.sleep(3) + sleep(3) # make sure the process is stopped admin_password = get_password(juju, CharmUsers.VALKEY_ADMIN) @@ -706,7 +711,12 @@ async def test_full_cluster_reboot( timeout=1200, ) - c_writes.update() + configure_cw_runner( + juju, + valkey_app=app_name, + tls_enabled=tls_enabled, + substrate=substrate, + ) for unit, unit_info in juju.status().get_units(app_name).items(): unit_ip = unit_info.public_address if substrate == Substrate.VM else unit_info.address @@ -718,28 +728,22 @@ async def test_full_cluster_reboot( logger.info("All units are available again.") logger.info("Checking number of connected replicas after primary restart.") - addresses = get_cluster_addresses(juju, app_name) - number_of_replicas = await get_number_connected_replicas( - addresses, CharmUsers.VALKEY_ADMIN, admin_password, tls_enabled=tls_enabled - ) + + number_of_replicas = get_number_connected_replicas(juju, tls_enabled=tls_enabled) assert number_of_replicas == init_units_count - 1, ( f"Expected {init_units_count - 1} replicas to be connected after primary restart, got {number_of_replicas}" ) # ensure data is written in the cluster logger.info("Checking continuous writes are increasing after primary restart.") - await assert_continuous_writes_increasing( - hostnames=addresses, - username=CharmUsers.VALKEY_ADMIN, - password=admin_password, - tls_enabled=tls_enabled, - ) + assert_continuous_writes_increasing(juju) - await c_writes.async_stop() + stats = stop_continuous_writes(juju) assert_continuous_writes_consistent( - hostnames=addresses, - username=CharmUsers.VALKEY_ADMIN, - password=admin_password, + endpoints=get_cluster_addresses(juju, app_name), + username=CharmUsers.VALKEY_ADMIN.value, + password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), + last_written_value=stats.last_written_value, tls_enabled=tls_enabled, ) diff --git a/tests/integration/ha/test_network_cut.py b/tests/integration/ha/test_network_cut.py index 18cf820..658647a 100644 --- a/tests/integration/ha/test_network_cut.py +++ b/tests/integration/ha/test_network_cut.py @@ -10,6 +10,9 @@ from literals import Substrate from tests.integration.cw_helpers import ( assert_continuous_writes_increasing, + configure_cw_runner, + start_continuous_writes, + stop_continuous_writes, ) from tests.integration.ha.helpers.helpers import ( cut_network_from_unit, @@ -24,16 +27,15 @@ ) from tests.integration.helpers import ( APP_NAME, + GLIDE_RUNNER_NAME, IMAGE_RESOURCE, TLS_CHANNEL, TLS_NAME, - CharmUsers, are_apps_active_and_agents_idle, download_client_certificate_from_unit, get_cluster_addresses, get_ip_from_unit, get_number_connected_replicas, - get_password, get_primary_ip, ) @@ -44,7 +46,11 @@ @pytest.mark.parametrize("tls_enabled", [False, True], ids=["tls_off", "tls_on"]) def test_build_and_deploy( - tls_enabled: bool, charm: str, juju: jubilant.Juju, substrate: Substrate + tls_enabled: bool, + charm: str, + juju: jubilant.Juju, + substrate: Substrate, + glide_runner_charm: str, ) -> None: """Build the charm-under-test and deploy it with three units.""" juju.deploy( @@ -53,13 +59,16 @@ def test_build_and_deploy( num_units=NUM_UNITS, trust=True, ) + juju.deploy(glide_runner_charm, app=GLIDE_RUNNER_NAME) if tls_enabled: juju.deploy(TLS_NAME, channel=TLS_CHANNEL) juju.integrate(f"{APP_NAME}:client-certificates", TLS_NAME) juju.wait( - lambda status: are_apps_active_and_agents_idle(status, APP_NAME, idle_period=30), + lambda status: are_apps_active_and_agents_idle( + status, APP_NAME, GLIDE_RUNNER_NAME, idle_period=30 + ), timeout=600, ) @@ -70,14 +79,13 @@ def test_build_and_deploy( @pytest.mark.parametrize("tls_enabled", [False, True], ids=["tls_off", "tls_on"]) @pytest.mark.parametrize("ip_change", [True, False], ids=["ip_change", "no_ip_change"]) -async def test_network_cut_primary( # noqa: C901 +def test_network_cut_primary( # noqa: C901 tls_enabled: bool, ip_change: bool, juju: jubilant.Juju, substrate: Substrate, chaos_mesh, - c_writes, - c_writes_async_clean, + glide_runner, ) -> None: """Cut the network to the primary unit and verify that a new primary is elected.""" if ip_change and substrate == Substrate.K8S: @@ -86,9 +94,8 @@ async def test_network_cut_primary( # noqa: C901 download_client_certificate_from_unit(juju, APP_NAME) addresses = get_cluster_addresses(juju, APP_NAME) - c_writes.tls_enabled = tls_enabled - await c_writes.async_clear() - c_writes.start() + configure_cw_runner(juju, valkey_app=APP_NAME, tls_enabled=tls_enabled, substrate=substrate) + start_continuous_writes(juju, clear=True) # Get the current primary unit primary_ip = get_primary_ip(juju, APP_NAME, tls_enabled=tls_enabled) @@ -169,12 +176,7 @@ async def test_network_cut_primary( # noqa: C901 # retry in case cluster hasn't stabilized yet after primary cut and new primary election for attempt in Retrying(stop=stop_after_attempt(10), wait=wait_fixed(10), reraise=True): with attempt: - number_of_replicas = await get_number_connected_replicas( - addresses=addresses, - username=CharmUsers.VALKEY_ADMIN.value, - password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), - tls_enabled=tls_enabled, - ) + number_of_replicas = get_number_connected_replicas(juju, tls_enabled=tls_enabled) assert number_of_replicas == NUM_UNITS - 2, ( f"Expected {NUM_UNITS - 2} connected replicas, got {number_of_replicas}." ) @@ -195,12 +197,7 @@ async def test_network_cut_primary( # noqa: C901 f"The old primary endpoint should be marked as down in sentinels list of hostname {address} after network cut." ) - await assert_continuous_writes_increasing( - hostnames=addresses, - username=CharmUsers.VALKEY_ADMIN.value, - password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), - tls_enabled=tls_enabled, - ) + assert_continuous_writes_increasing(juju) # restore network to the original primary unit logger.info("Restoring network to original primary unit at %s", primary_hostname) @@ -215,7 +212,9 @@ async def test_network_cut_primary( # noqa: C901 ip_change=ip_change, unit_count=NUM_UNITS, ) - c_writes.update() + configure_cw_runner( + juju, valkey_app=APP_NAME, tls_enabled=tls_enabled, substrate=substrate + ) # update hostnames after network restore logger.info( "Verifying that all units can reach the original primary unit at %s...", @@ -257,12 +256,7 @@ async def test_network_cut_primary( # noqa: C901 # sometimes it takes some time for the old primary to be marked as replica and for sentinels to update their status, so we add a retry here for attempt in Retrying(stop=stop_after_attempt(10), wait=wait_fixed(10), reraise=True): with attempt: - number_of_replicas = await get_number_connected_replicas( - addresses=addresses, - username=CharmUsers.VALKEY_ADMIN.value, - password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), - tls_enabled=tls_enabled, - ) + number_of_replicas = get_number_connected_replicas(juju, tls_enabled=tls_enabled) assert number_of_replicas == NUM_UNITS - 1, ( f"Expected {NUM_UNITS - 1} connected replicas after network restoration, got {number_of_replicas}." ) @@ -288,9 +282,5 @@ async def test_network_cut_primary( # noqa: C901 f"The old primary endpoint should be present in sentinels list of hostname {address} after network cut and no IP change." ) - await assert_continuous_writes_increasing( - hostnames=addresses, - username=CharmUsers.VALKEY_ADMIN.value, - password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), - tls_enabled=tls_enabled, - ) + assert_continuous_writes_increasing(juju) + stop_continuous_writes(juju) diff --git a/tests/integration/ha/test_scaling.py b/tests/integration/ha/test_scaling.py index 64d144d..226a725 100644 --- a/tests/integration/ha/test_scaling.py +++ b/tests/integration/ha/test_scaling.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 # Copyright 2026 Canonical Ltd. # See LICENSE file for licensing details. -import asyncio import logging +from time import sleep import jubilant import pytest @@ -11,9 +11,13 @@ from tests.integration.cw_helpers import ( assert_continuous_writes_consistent, assert_continuous_writes_increasing, + configure_cw_runner, + start_continuous_writes, + stop_continuous_writes, ) from tests.integration.helpers import ( APP_NAME, + GLIDE_RUNNER_NAME, IMAGE_RESOURCE, are_apps_active_and_agents_idle, existing_app, @@ -23,7 +27,6 @@ get_primary_ip, get_quorum, remove_number_units, - seed_valkey, ) logger = logging.getLogger(__name__) @@ -31,9 +34,12 @@ NUM_UNITS = 3 TEST_KEY = "test_key" TEST_VALUE = "test_value" +SEED_KEY_PREFIX = "seed:key:" -def test_build_and_deploy(charm: str, juju: jubilant.Juju, substrate: Substrate) -> None: +def test_build_and_deploy( + charm: str, juju: jubilant.Juju, substrate: Substrate, glide_runner_charm +) -> None: """Build the charm-under-test and deploy it with three units.""" if existing_app(juju): return @@ -44,8 +50,11 @@ def test_build_and_deploy(charm: str, juju: jubilant.Juju, substrate: Substrate) num_units=1, trust=True, ) + juju.deploy(glide_runner_charm, app=GLIDE_RUNNER_NAME) juju.wait( - lambda status: are_apps_active_and_agents_idle(status, APP_NAME, idle_period=30), + lambda status: are_apps_active_and_agents_idle( + status, APP_NAME, GLIDE_RUNNER_NAME, idle_period=30 + ), timeout=600, ) @@ -54,12 +63,22 @@ def test_build_and_deploy(charm: str, juju: jubilant.Juju, substrate: Substrate) ) -async def test_seed_data(juju: jubilant.Juju) -> None: +def test_seed_data(juju: jubilant.Juju, substrate: Substrate) -> None: """Seed some data to the cluster.""" - await seed_valkey(juju, target_gb=1) + configure_cw_runner(juju, substrate=substrate) + task = juju.run( + f"{GLIDE_RUNNER_NAME}/leader", + "seed-data", + params={ + "target-gb": 1.0, + "key-prefix": SEED_KEY_PREFIX, + }, + ) + if task.status != "completed": + logger.error(f"Data seeding failed: {task.results}") -async def test_check_quorum(juju: jubilant.Juju) -> None: +def test_check_quorum(juju: jubilant.Juju) -> None: """Check quorum value.""" app_name = existing_app(juju) or APP_NAME init_units_count = len(juju.status().apps[app_name].units) @@ -68,12 +87,12 @@ async def test_check_quorum(juju: jubilant.Juju) -> None: ) -async def test_scale_up(juju: jubilant.Juju, c_writes) -> None: +def test_scale_up(juju: jubilant.Juju, glide_runner, substrate: Substrate) -> None: """Make sure new units are added to the valkey downtime.""" app_name = existing_app(juju) or APP_NAME init_units_count = len(juju.status().apps[app_name].units) - await c_writes.async_clear() - c_writes.start() + configure_cw_runner(juju, valkey_app=app_name, substrate=substrate) + start_continuous_writes(juju, clear=True) # scale up juju.add_unit(app_name, num_units=2) @@ -96,31 +115,23 @@ async def test_scale_up(juju: jubilant.Juju, c_writes) -> None: # check if all units have been added to the cluster addresses = get_cluster_addresses(juju, app_name) - connected_replicas = await get_number_connected_replicas( - addresses=addresses, - username=CharmUsers.VALKEY_ADMIN.value, - password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), - ) + connected_replicas = get_number_connected_replicas(juju) assert connected_replicas == init_units_count + 1, ( f"Expected {init_units_count + 1} connected replicas, got {connected_replicas}." ) - await assert_continuous_writes_increasing( - hostnames=addresses, - username=CharmUsers.VALKEY_ADMIN.value, - password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), - ) + assert_continuous_writes_increasing(juju) logger.info("Stopping continuous writes after scale up test.") - logger.info(await c_writes.async_stop()) + cw_stats = stop_continuous_writes(juju) assert_continuous_writes_consistent( - hostnames=addresses, + endpoints=addresses, username=CharmUsers.VALKEY_ADMIN.value, password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), + last_written_value=cw_stats.last_written_value, ) - await c_writes.async_clear() -async def test_scale_down_one_unit(juju: jubilant.Juju, substrate: Substrate, c_writes) -> None: +def test_scale_down_one_unit(juju: jubilant.Juju, substrate: Substrate, glide_runner) -> None: """Make sure scale down operations complete successfully.""" app_name = existing_app(juju) or APP_NAME init_units_count = len(juju.status().apps[app_name].units) @@ -135,18 +146,14 @@ async def test_scale_down_one_unit(juju: jubilant.Juju, substrate: Substrate, c_ timeout=1200, ) - number_of_replicas = await get_number_connected_replicas( - addresses=get_cluster_addresses(juju, app_name), - username=CharmUsers.VALKEY_ADMIN.value, - password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), - ) + number_of_replicas = get_number_connected_replicas(juju) assert number_of_replicas == init_units_count - 1, ( f"Expected {init_units_count - 1} connected replicas, got {number_of_replicas}." ) - await c_writes.async_clear() - c_writes.start() - await asyncio.sleep(10) # let the continuous writes write some data + configure_cw_runner(juju, valkey_app=app_name, substrate=substrate) + start_continuous_writes(juju, clear=True) + sleep(10) # let the continuous writes write some data # scale down remove_number_units(juju, app_name, num_units=1, substrate=substrate) @@ -165,37 +172,28 @@ async def test_scale_down_one_unit(juju: jubilant.Juju, substrate: Substrate, c_ f"Unexpected quorum value for unit {unit} after scale down" ) - number_of_replicas = await get_number_connected_replicas( - addresses=get_cluster_addresses(juju, app_name), - username=CharmUsers.VALKEY_ADMIN.value, - password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), - ) + number_of_replicas = get_number_connected_replicas(juju) assert number_of_replicas == init_units_count - 2, ( f"Expected {init_units_count - 2} connected replicas, got {number_of_replicas}." ) # update hostnames after scale down - c_writes.update() + configure_cw_runner(juju, valkey_app=app_name, substrate=substrate) - await assert_continuous_writes_increasing( - hostnames=get_cluster_addresses(juju, app_name), - username=CharmUsers.VALKEY_ADMIN.value, - password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), - ) + assert_continuous_writes_increasing(juju) logger.info("Stopping continuous writes after scale down test.") - logger.info(await c_writes.async_stop()) - + cw_stats = stop_continuous_writes(juju) assert_continuous_writes_consistent( - hostnames=get_cluster_addresses(juju, app_name), + endpoints=get_cluster_addresses(juju, app_name), username=CharmUsers.VALKEY_ADMIN.value, password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), + last_written_value=cw_stats.last_written_value, ) - await c_writes.async_clear() -async def test_scale_down_multiple_units( - juju: jubilant.Juju, substrate: Substrate, c_writes +def test_scale_down_multiple_units( + juju: jubilant.Juju, substrate: Substrate, glide_runner ) -> None: """Make sure multiple scale down operations complete successfully.""" app_name = existing_app(juju) or APP_NAME @@ -210,18 +208,15 @@ async def test_scale_down_multiple_units( ) init_units_count = NUM_UNITS + 1 - number_of_replicas = await get_number_connected_replicas( - addresses=get_cluster_addresses(juju, app_name), - username=CharmUsers.VALKEY_ADMIN.value, - password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), - ) + number_of_replicas = get_number_connected_replicas(juju) assert number_of_replicas == init_units_count - 1, ( f"Expected {init_units_count - 1} connected replicas, got {number_of_replicas}." ) - await c_writes.async_clear() - c_writes.start() - await asyncio.sleep(10) # let the continuous writes write some data + configure_cw_runner(juju, valkey_app=app_name, substrate=substrate) + start_continuous_writes(juju, clear=True) + + sleep(10) # let the continuous writes write some data # scale down multiple units remove_number_units(juju, app_name, num_units=2, substrate=substrate) @@ -236,11 +231,7 @@ async def test_scale_down_multiple_units( f"Expected {init_units_count - 2} units, got {num_units}." ) - number_of_replicas = await get_number_connected_replicas( - addresses=get_cluster_addresses(juju, app_name), - username=CharmUsers.VALKEY_ADMIN.value, - password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), - ) + number_of_replicas = get_number_connected_replicas(juju) assert number_of_replicas == init_units_count - 3, ( f"Expected {init_units_count - 3} connected replicas, got {number_of_replicas}." ) @@ -250,27 +241,25 @@ async def test_scale_down_multiple_units( f"Unexpected quorum value for unit {unit} after scale down" ) - c_writes.update() + configure_cw_runner( + juju, valkey_app=app_name, substrate=substrate + ) # update hostnames after scale down - await assert_continuous_writes_increasing( - hostnames=get_cluster_addresses(juju, app_name), - username=CharmUsers.VALKEY_ADMIN.value, - password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), - ) + assert_continuous_writes_increasing(juju) logger.info("Stopping continuous writes after scale down test.") - logger.info(await c_writes.async_stop()) + cw_stats = stop_continuous_writes(juju) assert_continuous_writes_consistent( - hostnames=get_cluster_addresses(juju, app_name), + endpoints=get_cluster_addresses(juju, app_name), username=CharmUsers.VALKEY_ADMIN.value, password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), + last_written_value=cw_stats.last_written_value, ) - await c_writes.async_clear() -async def test_scale_down_to_zero_and_back_up( - juju: jubilant.Juju, substrate: Substrate, c_writes +def test_scale_down_to_zero_and_back_up( + juju: jubilant.Juju, substrate: Substrate, glide_runner ) -> None: """Make sure that removing all units and then adding them again works.""" app_name = existing_app(juju) or APP_NAME @@ -292,33 +281,29 @@ async def test_scale_down_to_zero_and_back_up( addresses = get_cluster_addresses(juju, app_name) - connected_replicas = await get_number_connected_replicas( - addresses=addresses, - username=CharmUsers.VALKEY_ADMIN.value, - password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), - ) + connected_replicas = get_number_connected_replicas(juju) assert connected_replicas == NUM_UNITS - 1, ( f"Expected {NUM_UNITS - 1} connected replicas, got {connected_replicas}." ) - await c_writes.async_clear() - c_writes.start() - await asyncio.sleep(10) # let the continuous writes write some data - await assert_continuous_writes_increasing( - hostnames=addresses, - username=CharmUsers.VALKEY_ADMIN.value, - password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), - ) + + configure_cw_runner(juju, valkey_app=app_name, substrate=substrate) + start_continuous_writes(juju, clear=True) + + sleep(10) # let the continuous writes write some data + assert_continuous_writes_increasing(juju) + logger.info("Stopping continuous writes after scale up test.") - logger.info(await c_writes.async_stop()) + cw_stats = stop_continuous_writes(juju) + assert_continuous_writes_consistent( - hostnames=addresses, + endpoints=addresses, username=CharmUsers.VALKEY_ADMIN.value, password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), + last_written_value=cw_stats.last_written_value, ) - await c_writes.async_clear() -async def test_scale_down_primary(juju: jubilant.Juju, substrate: Substrate, c_writes) -> None: +def test_scale_down_primary(juju: jubilant.Juju, substrate: Substrate, glide_runner) -> None: """Make sure that removing the primary unit triggers a new primary to be elected and the cluster remains available.""" if substrate == Substrate.K8S: pytest.skip("Primary unit can only targeted on VM") @@ -335,8 +320,10 @@ async def test_scale_down_primary(juju: jubilant.Juju, substrate: Substrate, c_w ) init_units_count = NUM_UNITS - await c_writes.async_clear() - c_writes.start() + configure_cw_runner(juju, valkey_app=app_name, substrate=substrate) + start_continuous_writes(juju, clear=True) + sleep(10) # let the continuous writes write some data + primary_endpoint = get_primary_ip(juju, app_name) primary_unit = next( unit @@ -355,26 +342,23 @@ async def test_scale_down_primary(juju: jubilant.Juju, substrate: Substrate, c_w status, app_name, unit_count=init_units_count - 1, idle_period=10 ) ) - c_writes.update() + configure_cw_runner( + juju, valkey_app=app_name, substrate=substrate + ) # update hostnames after primary unit removal new_primary_endpoint = get_primary_ip(juju, app_name) assert new_primary_endpoint != primary_endpoint, ( "Primary endpoint did not change after removing primary unit." ) logger.info(f"New primary endpoint after scale down is {new_primary_endpoint}.") - hostnames = get_cluster_addresses(juju, app_name) - await assert_continuous_writes_increasing( - hostnames=hostnames, - username=CharmUsers.VALKEY_ADMIN.value, - password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), - ) - logger.info("Stopping continuous writes after primary scale down test.") - logger.info(await c_writes.async_stop()) + endpoints = get_cluster_addresses(juju, app_name) + assert_continuous_writes_increasing(juju) + cw_stats = stop_continuous_writes(juju) assert_continuous_writes_consistent( - hostnames=hostnames, + endpoints=endpoints, username=CharmUsers.VALKEY_ADMIN.value, password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), + last_written_value=cw_stats.last_written_value, ) - await c_writes.async_clear() def test_scale_down_remove_application(juju: jubilant.Juju) -> None: diff --git a/tests/integration/helpers.py b/tests/integration/helpers.py index 1f1d963..647a98a 100644 --- a/tests/integration/helpers.py +++ b/tests/integration/helpers.py @@ -4,11 +4,9 @@ import json import logging -import os import re import subprocess -import time -from contextlib import asynccontextmanager, contextmanager +from contextlib import contextmanager from datetime import datetime, timedelta from pathlib import Path from typing import List, Literal, NamedTuple @@ -19,9 +17,7 @@ from dateutil.parser import parse from glide import ( AdvancedGlideClientConfiguration, - GlideClient, GlideClientConfiguration, - InfoSection, NodeAddress, ServerCredentials, TlsAdvancedConfiguration, @@ -39,17 +35,18 @@ CharmUsers, Substrate, ) +from tests.integration.glide_helpers import serialize_glide_config logger = logging.getLogger(__name__) METADATA = yaml.safe_load(Path("./metadata.yaml").read_text()) APP_NAME: str = METADATA["name"] +GLIDE_RUNNER_NAME = "glide-runner" IMAGE_RESOURCE = {"valkey-image": METADATA["resources"]["valkey-image"]["upstream-source"]} INTERNAL_USERS_SECRET_LABEL = ( f"{PEER_RELATION}.{APP_NAME}.app.{INTERNAL_USERS_SECRET_LABEL_SUFFIX}" ) -SEED_KEY_PREFIX = "seed:key:" TLS_NAME = "self-signed-certificates" TLS_CHANNEL = "1/edge" TLS_CERT_FILE = "client.pem" @@ -248,6 +245,27 @@ def get_cluster_addresses(juju: jubilant.Juju, app_name: str) -> list[str]: return [unit.public_address for unit in status.get_units(app_name).values()] +def get_cluster_endpoints(juju: jubilant.Juju, app_name: str) -> list[str]: + """Get the addresses of all units in the Valkey application. + + Args: + juju: The Juju client instance. + app_name: The name of the Valkey application. + + Returns: + A list of addresses for all units in the Valkey application. + """ + model_info = juju.show_model() + + if model_info.type == "kubernetes": + return [ + unit_name.replace("/", "-") + "." + app_name + "-endpoints" + for unit_name in juju.status().get_units(app_name) + ] + + return get_cluster_addresses(juju, app_name) + + def get_secret_by_label(juju: jubilant.Juju, label: str) -> dict[str, str]: for secret in juju.secrets(): if label == secret.label: @@ -257,32 +275,25 @@ def get_secret_by_label(juju: jubilant.Juju, label: str) -> dict[str, str]: raise SecretNotFoundError(f"Secret with label {label} not found") -@asynccontextmanager -async def create_valkey_client( - hostnames: list[str], +def get_glide_config( + juju: jubilant.Juju, + app_name: str, + endpoints: list[str] | None = None, username: str | None = CharmUsers.VALKEY_ADMIN.value, password: str | None = None, tls_enabled: bool = False, -): - """Create and return a Valkey client connected to the cluster. - - Args: - hostnames: List of hostnames of the Valkey cluster nodes. - username: The username for authentication. - password: The password for the internal user. - tls_enabled: Whether TLS certificates are needed. - - Returns: - A Valkey client instance connected to the cluster. - """ +) -> GlideClientConfiguration: + """Construct a GlideClientConfiguration from Juju model information and secrets.""" + endpoints = endpoints or get_cluster_endpoints(juju, app_name) addresses = [ - NodeAddress(host=host, port=TLS_PORT if tls_enabled else CLIENT_PORT) for host in hostnames + NodeAddress(host=host, port=TLS_PORT if tls_enabled else CLIENT_PORT) for host in endpoints ] credentials = None if username or password: credentials = ServerCredentials(username=username, password=password) + tls_cert = tls_key = tls_ca_cert = None if tls_enabled: # Read locally stored certificate files with open("client.pem", "rb") as f: @@ -296,10 +307,6 @@ async def create_valkey_client( client_cert_pem=tls_cert if tls_enabled else None, client_key_pem=tls_key if tls_enabled else None, root_pem_cacerts=tls_ca_cert if tls_enabled else None, - # We only set FQDN in the certs the IP is not in the cert - # so we need to skip hostname verification - # we cannot use the hostname because the runner cannot resolve it - use_insecure_tls=True if tls_enabled else None, ) client_config = GlideClientConfiguration( @@ -308,12 +315,7 @@ async def create_valkey_client( use_tls=True if tls_enabled else False, advanced_config=AdvancedGlideClientConfiguration(tls_config=tls_config), ) - - client = await GlideClient.create(client_config) - try: - yield client - finally: - await client.close() + return client_config def set_password( @@ -418,62 +420,6 @@ def get_password(juju: jubilant.Juju, user: CharmUsers = CharmUsers.VALKEY_ADMIN return secret.get(f"{user.value}-password", "") -async def seed_valkey(juju: jubilant.Juju, target_gb: float = 1.0) -> None: - # Connect to Valkey - addresses = get_cluster_addresses(juju, APP_NAME) - - # Configuration - value_size_bytes = 1024 # 1KB per value - batch_size = 5000 # Commands per pipeline - total_bytes_target = target_gb * 1024 * 1024 * 1024 - total_keys = total_bytes_target // value_size_bytes - - logger.info( - "Targeting ~%sGB (%s keys of %s bytes each)", - target_gb, - total_keys, - value_size_bytes, - ) - - start_time = time.time() - keys_added = 0 - - # Generate a fixed random block to reuse (saves CPU cycles on generation) - random_data = os.urandom(value_size_bytes).hex()[:value_size_bytes] - async with create_valkey_client(addresses, password=get_password(juju)) as client: - try: - while keys_added < total_keys: - data = { - f"{SEED_KEY_PREFIX}{key_idx}": random_data - for key_idx in range(keys_added, keys_added + batch_size) - } - - if await client.mset(data) != "OK": - raise RuntimeError("Failed to set data in Valkey cluster") - - keys_added += batch_size - - # Progress reporting - elapsed = time.time() - start_time - percent = (keys_added / total_keys) * 100 - logger.info( - "Progress: %.1f%% | Keys: %s | Elapsed: %.1f s", - percent, - keys_added, - elapsed, - ) - - except Exception as e: - logger.error("Error: %s", e) - finally: - total_time = time.time() - start_time - logger.info( - "Seeding complete! Added %s keys in %.2f seconds.", - keys_added, - total_time, - ) - - valkey_cli_result = NamedTuple( "ValkeyCliResult", [("stdout", str), ("stderr", str), ("returncode", int)] ) @@ -530,32 +476,47 @@ def get_quorum(juju: jubilant.Juju, unit_name: str) -> int: return int(json.loads(result.stdout)["quorum"]) -async def set_key( - hostnames: list[str], +def set_key( + juju: jubilant.Juju, + endpoints: list[str], username: str, password: str, key: str, value: str, tls_enabled: bool = False, -) -> bytes | None: +) -> str: """Write a key-value pair to the Valkey cluster. Args: - hostnames: List of hostnames of the Valkey cluster nodes. - key: The key to write. - value: The value to write. + juju: An instance of Jubilant's Juju class on which to run Juju commands + endpoints: List of endpoints of the Valkey cluster nodes. username: The username for authentication. password: The password for authentication. + key: The key to set. + value: The value to set. tls_enabled: Whether TLS certificates are needed. """ - async with create_valkey_client( - hostnames=hostnames, username=username, password=password, tls_enabled=tls_enabled - ) as client: - return await client.set(key, value) + glide_config = get_glide_config( + juju=juju, + app_name=APP_NAME, + endpoints=endpoints, + username=username, + password=password, + tls_enabled=tls_enabled, + ) + task = juju.run( + f"{GLIDE_RUNNER_NAME}/leader", + "execute", + params={"command": f"SET {key} {value}", "config": serialize_glide_config(glide_config)}, + ) + if task.status != "completed": + raise RuntimeError(f"Command execution failed: {task.results}") + return json.loads(task.results.get("result", "null")) -async def get_key( - hostnames: list[str], +def get_key( + juju: jubilant.Juju, + endpoints: list[str], username: str, password: str, key: str, @@ -564,16 +525,29 @@ async def get_key( """Read a value from the Valkey cluster by key. Args: - hostnames: List of hostnames of the Valkey cluster nodes. + juju: An instance of Jubilant's Juju class on which to run Juju commands + endpoints: List of endpoints of the Valkey cluster nodes. key: The key to read. username: The username for authentication. password: The password for authentication. tls_enabled: Whether TLS certificates are needed. """ - async with create_valkey_client( - hostnames=hostnames, username=username, password=password, tls_enabled=tls_enabled - ) as client: - return await client.get(key) + glide_config = get_glide_config( + juju=juju, + app_name=APP_NAME, + endpoints=endpoints, + username=username, + password=password, + tls_enabled=tls_enabled, + ) + task = juju.run( + f"{GLIDE_RUNNER_NAME}/leader", + "execute", + params={"command": f"GET {key}", "config": serialize_glide_config(glide_config)}, + ) + if task.status != "completed": + raise RuntimeError(f"Command execution failed: {task.results}") + return json.loads(task.results.get("result", "null")) def ping( @@ -603,8 +577,10 @@ def ping( return False -async def ping_cluster( - hostnames: list[str], +def ping_cluster( + juju: jubilant.Juju, + app_name: str, + endpoints: list[str], username: str, password: str, tls_enabled: bool = False, @@ -612,7 +588,9 @@ async def ping_cluster( """Ping all nodes in the Valkey cluster. Args: - hostnames: List of hostnames of the Valkey cluster nodes. + juju: An instance of Jubilant's Juju class on which to run Juju commands + app_name: The name of the Valkey application + endpoints: List of endpoints of the Valkey cluster nodes. username: The username for authentication. password: The password for authentication. tls_enabled: Whether TLS certificates are needed. @@ -620,37 +598,51 @@ async def ping_cluster( Returns: True if all nodes respond to a ping, False otherwise. """ - async with create_valkey_client( - hostnames=hostnames, username=username, password=password, tls_enabled=tls_enabled - ) as client: - return await client.ping() == "PONG".encode() + glide_config = get_glide_config( + juju=juju, + app_name=app_name, + endpoints=endpoints, + username=username, + password=password, + tls_enabled=tls_enabled, + ) + task = juju.run( + f"{GLIDE_RUNNER_NAME}/leader", + "execute", + params={"command": "ping", "config": serialize_glide_config(glide_config)}, + ) + return task.status == "completed" and json.loads(task.results.get("result", "")) == "PONG" -async def get_number_connected_replicas( - addresses: list[str], - username: str, - password: str, +def get_number_connected_replicas( + juju: jubilant.Juju, + glide_runner_unit: str = f"{GLIDE_RUNNER_NAME}/leader", tls_enabled: bool = False, ) -> int: """Get the number of connected replicas in the Valkey cluster. Args: - addresses: List of addresses of the Valkey cluster nodes. - username: The username for authentication. - password: The password for authentication. + juju: An instance of Jubilant's Juju class on which to run Juju commands + glide_runner_unit: The unit name of the glide-runner to execute the command on tls_enabled: Whether TLS certificates are needed. Returns: The number of connected replicas. """ - async with create_valkey_client( - hostnames=addresses, - username=username, - password=password, + glide_config = get_glide_config( + juju=juju, + app_name=APP_NAME, + username=CharmUsers.VALKEY_ADMIN.value, + password=get_password(juju), tls_enabled=tls_enabled, - ) as client: - info = (await client.info([InfoSection.REPLICATION])).decode() - search_result = re.search(r"connected_slaves:([\d+])", info) + ) + task_result = juju.run( + glide_runner_unit, + "execute", + {"command": "info replication", "config": serialize_glide_config(glide_config)}, + ) + assert task_result.status == "completed", f"Command execution failed: {task_result.results}" + search_result = re.search(r"connected_slaves:([\d+])", task_result.results.get("result", "")) if not search_result: raise ValueError("Could not parse number of connected replicas from info output") return int(search_result.group(1)) @@ -664,33 +656,46 @@ class WrongPassError(Exception): """Raised when authentication fails due to incorrect credentials.""" -async def auth_test( - hostnames: list[str], username: str | None, password: str | None, tls_enabled: bool = False +def auth_test( + juju: jubilant.Juju, + endpoints: list[str] | None = None, + username: str | None = None, + password: str | None = None, + tls_enabled: bool = False, + glide_runner_unit: str = f"{GLIDE_RUNNER_NAME}/leader", ) -> bool: """Test authentication to the Valkey cluster by attempting to ping it. Args: - hostnames: List of hostnames of the Valkey cluster nodes. + juju: An instance of Jubilant's Juju class on which to run Juju commands + endpoints: List of endpoints of the Valkey cluster nodes. If None, will be retrieved from Juju. username: The username for authentication. password: The password for authentication. tls_enabled: Whether TLS certificates are needed. + glide_runner_unit: The unit name of the glide-runner to execute the command on Returns: True if authentication is successful and the cluster responds to a ping, False otherwise. """ - try: - async with create_valkey_client( - hostnames=hostnames, username=username, password=password, tls_enabled=tls_enabled - ) as client: - return await client.ping() == "PONG".encode() - except Exception as e: - error_message = str(e) - if "NOAUTH" in error_message: - raise NoAuthError("Authentication failed: NOAUTH error") from e - elif "WRONGPASS" in error_message: - raise WrongPassError("Authentication failed: WRONGPASS error") from e - else: - raise e + glide_config = get_glide_config( + juju=juju, + endpoints=endpoints, + app_name=APP_NAME, + username=username, + password=password, + tls_enabled=tls_enabled, + ) + task = juju.run( + glide_runner_unit, + "execute", + params={"command": "ping", "config": serialize_glide_config(glide_config)}, + ) + result = json.loads(task.results.get("result", "")) + if "NOAUTH" in result: + raise NoAuthError("Authentication failed: NOAUTH error") + elif "WRONGPASS" in result: + raise WrongPassError("Authentication failed: WRONGPASS error") + return task.status == "completed" and result == "PONG" def remove_number_units( diff --git a/tests/integration/test_charm.py b/tests/integration/test_charm.py index e80ce6e..19766a5 100644 --- a/tests/integration/test_charm.py +++ b/tests/integration/test_charm.py @@ -14,6 +14,7 @@ from statuses import CharmStatuses, ClusterStatuses from tests.integration.helpers import ( APP_NAME, + GLIDE_RUNNER_NAME, IMAGE_RESOURCE, INTERNAL_USERS_SECRET_LABEL, NoAuthError, @@ -24,6 +25,7 @@ exec_valkey_cli, fast_forward, get_cluster_addresses, + get_cluster_endpoints, get_password, get_secret_by_label, ping, @@ -39,7 +41,9 @@ TEST_VALUE = "test_value" -def test_build_and_deploy(charm: str, juju: jubilant.Juju, substrate: Substrate) -> None: +def test_build_and_deploy( + charm: str, juju: jubilant.Juju, substrate: Substrate, glide_runner_charm: str +) -> None: """Build the charm-under-test and deploy it with three units.""" juju.deploy( charm, @@ -47,21 +51,25 @@ def test_build_and_deploy(charm: str, juju: jubilant.Juju, substrate: Substrate) num_units=NUM_UNITS, trust=True, ) + juju.deploy(glide_runner_charm, app=GLIDE_RUNNER_NAME) juju.wait( - lambda status: are_apps_active_and_agents_idle(status, APP_NAME, idle_period=30), + lambda status: are_apps_active_and_agents_idle( + status, APP_NAME, GLIDE_RUNNER_NAME, idle_period=30 + ), timeout=600, delay=5, successes=3, ) -async def test_authentication(juju: jubilant.Juju) -> None: +def test_authentication(juju: jubilant.Juju) -> None: """Assert that we can authenticate to valkey.""" addresses = get_cluster_addresses(juju, APP_NAME) + endpoints = get_cluster_endpoints(juju, APP_NAME) # try without authentication with pytest.raises(NoAuthError): - await auth_test(addresses, username=None, password=None) + auth_test(juju, endpoints=endpoints, username=None, password=None) # Authenticate with internal user password = get_password(juju, user=CharmUsers.VALKEY_ADMIN) @@ -74,7 +82,7 @@ async def test_authentication(juju: jubilant.Juju) -> None: ), "Failed to authenticate with Valkey cluster using CLI" -async def test_update_admin_password(juju: jubilant.Juju) -> None: +def test_update_admin_password(juju: jubilant.Juju) -> None: """Assert the admin password is updated when adding a user secret to the config.""" # create a user secret and grant it to the application logger.info("Updating operator password") @@ -91,21 +99,25 @@ async def test_update_admin_password(juju: jubilant.Juju) -> None: new_password_secret = get_password(juju, user=CharmUsers.VALKEY_ADMIN) assert new_password_secret == new_password, "Admin password not updated in secret" - addresses = get_cluster_addresses(juju, APP_NAME) + endpoints = get_cluster_endpoints(juju, APP_NAME) # confirm old password no longer works with pytest.raises(WrongPassError): - await auth_test(addresses, username=CharmUsers.VALKEY_ADMIN.value, password=old_password) + auth_test( + juju, + endpoints=endpoints, + username=CharmUsers.VALKEY_ADMIN.value, + password=old_password, + ) assert ( - await ping_cluster( - addresses, username=CharmUsers.VALKEY_ADMIN.value, password=new_password - ) + ping_cluster(juju, APP_NAME, endpoints, CharmUsers.VALKEY_ADMIN.value, new_password) is True ), "Failed to authenticate with new admin password" assert ( - await set_key( - addresses, + set_key( + juju, + endpoints, username=CharmUsers.VALKEY_ADMIN.value, password=new_password, key=TEST_KEY, @@ -138,7 +150,7 @@ async def test_update_admin_password(juju: jubilant.Juju) -> None: ), f"Failed to read data after admin password update on host {address}" -async def test_update_admin_password_wrong_username(juju: jubilant.Juju) -> None: +def test_update_admin_password_wrong_username(juju: jubilant.Juju) -> None: """Assert the admin password is updated when adding a user secret to the config.""" # create a user secret and grant it to the application secret = get_secret_by_label(juju, label=INTERNAL_USERS_SECRET_LABEL) @@ -170,8 +182,10 @@ async def test_update_admin_password_wrong_username(juju: jubilant.Juju) -> None # perform read operation with the updated password assert ( - await ping_cluster( - get_cluster_addresses(juju, APP_NAME), + ping_cluster( + juju=juju, + app_name=APP_NAME, + endpoints=get_cluster_endpoints(juju, APP_NAME), username=CharmUsers.VALKEY_ADMIN.value, password=new_password, ) @@ -179,8 +193,9 @@ async def test_update_admin_password_wrong_username(juju: jubilant.Juju) -> None ), "Failed to authenticate with new admin password" assert ( - await set_key( - get_cluster_addresses(juju, APP_NAME), + set_key( + juju=juju, + endpoints=get_cluster_endpoints(juju, APP_NAME), username=CharmUsers.VALKEY_ADMIN.value, password=new_password, key=TEST_KEY, @@ -199,7 +214,7 @@ async def test_update_admin_password_wrong_username(juju: jubilant.Juju) -> None ) -async def test_user_secret_permissions(juju: jubilant.Juju) -> None: +def test_user_secret_permissions(juju: jubilant.Juju) -> None: """If a user secret is not granted, ensure we can process updated permissions.""" logger.info("Creating new user secret") secret_name = "my_secret" @@ -229,14 +244,19 @@ async def test_user_secret_permissions(juju: jubilant.Juju) -> None: ) # perform read operation with the updated password - addresses = get_cluster_addresses(juju, APP_NAME) - assert await ping_cluster( - addresses, username=CharmUsers.VALKEY_ADMIN.value, password=new_password + endpoints = get_cluster_endpoints(juju, APP_NAME) + assert ping_cluster( + juju=juju, + app_name=APP_NAME, + endpoints=endpoints, + username=CharmUsers.VALKEY_ADMIN.value, + password=new_password, ), "Failed to authenticate with new admin password" assert ( - await set_key( - addresses, + set_key( + juju=juju, + endpoints=endpoints, username=CharmUsers.VALKEY_ADMIN.value, password=new_password, key=TEST_KEY, @@ -245,7 +265,7 @@ async def test_user_secret_permissions(juju: jubilant.Juju) -> None: == "OK" ), "Failed to write data after admin password update" - for address in addresses: + for address in get_cluster_addresses(juju, APP_NAME): assert ( ping(address, username=CharmUsers.VALKEY_ADMIN.value, password=new_password) is True ), ( diff --git a/tests/integration/tls/test_certificate_options.py b/tests/integration/tls/test_certificate_options.py index 9d9776b..b43e8f7 100644 --- a/tests/integration/tls/test_certificate_options.py +++ b/tests/integration/tls/test_certificate_options.py @@ -13,6 +13,7 @@ from statuses import TLSStatuses from tests.integration.helpers import ( APP_NAME, + GLIDE_RUNNER_NAME, IMAGE_RESOURCE, TLS_CERT_FILE, TLS_CHANNEL, @@ -21,7 +22,7 @@ are_apps_active_and_agents_idle, does_status_match, download_client_certificate_from_unit, - get_cluster_addresses, + get_cluster_endpoints, get_password, set_key, ) @@ -34,7 +35,9 @@ VAULT_NAME = "vault" -def test_build_and_deploy(charm: str, juju: jubilant.Juju, substrate: Substrate) -> None: +def test_build_and_deploy( + charm: str, juju: jubilant.Juju, substrate: Substrate, glide_runner_charm +) -> None: """Deploy the charm under test and a TLS provider.""" logger.info("Installing vault cli client") subprocess.run( @@ -47,6 +50,7 @@ def test_build_and_deploy(charm: str, juju: jubilant.Juju, substrate: Substrate) num_units=NUM_UNITS, trust=True, ) + juju.deploy(glide_runner_charm, app=GLIDE_RUNNER_NAME) juju.deploy(TLS_NAME, channel=TLS_CHANNEL) juju.deploy( "vault-k8s" if substrate == Substrate.K8S else "vault", @@ -60,7 +64,13 @@ def test_build_and_deploy(charm: str, juju: jubilant.Juju, substrate: Substrate) ) juju.integrate(f"{APP_NAME}:client-certificates", TLS_NAME) juju.wait( - lambda status: are_agents_idle(status, APP_NAME, idle_period=30, unit_count=NUM_UNITS), + lambda status: are_agents_idle( + status, + APP_NAME, + GLIDE_RUNNER_NAME, + idle_period=30, + unit_count={APP_NAME: NUM_UNITS, GLIDE_RUNNER_NAME: 1}, + ), timeout=600, ) juju.wait(lambda status: jubilant.all_blocked(status, VAULT_NAME)) @@ -240,7 +250,7 @@ def test_initialize_vault(juju: jubilant.Juju, substrate: Substrate) -> None: juju.wait(lambda status: are_apps_active_and_agents_idle(status, VAULT_NAME)) -async def test_certificate_denied(juju: jubilant.Juju) -> None: +def test_certificate_denied(juju: jubilant.Juju) -> None: """Process denied certificate request.""" logger.info("Integrate %s with %s for Intermediate CA", VAULT_NAME, TLS_NAME) juju.integrate(f"{VAULT_NAME}:tls-certificates-pki", TLS_NAME) @@ -259,9 +269,10 @@ async def test_certificate_denied(juju: jubilant.Juju) -> None: ) logger.info("Ensure access without TLS is still possible") - addresses = get_cluster_addresses(juju, APP_NAME) - result = await set_key( - hostnames=addresses, + endpoints = get_cluster_endpoints(juju, APP_NAME) + result = set_key( + juju=juju, + endpoints=endpoints, username=CharmUsers.VALKEY_ADMIN.value, password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), tls_enabled=False, diff --git a/tests/integration/tls/test_certificate_rotation.py b/tests/integration/tls/test_certificate_rotation.py index 8b76b5b..08341eb 100644 --- a/tests/integration/tls/test_certificate_rotation.py +++ b/tests/integration/tls/test_certificate_rotation.py @@ -5,12 +5,12 @@ from time import sleep import jubilant -import pytest from literals import CharmUsers, Substrate from statuses import TLSStatuses from tests.integration.helpers import ( APP_NAME, + GLIDE_RUNNER_NAME, IMAGE_RESOURCE, TLS_CA_FILE, TLS_CERT_FILE, @@ -20,7 +20,7 @@ auth_test, does_status_match, download_client_certificate_from_unit, - get_cluster_addresses, + get_cluster_endpoints, get_key, get_password, set_key, @@ -48,7 +48,9 @@ def _prepare_units_for_ca_expiration_test(juju: jubilant.Juju) -> None: ) -def test_build_and_deploy(charm: str, juju: jubilant.Juju, substrate: Substrate) -> None: +def test_build_and_deploy( + charm: str, juju: jubilant.Juju, substrate: Substrate, glide_runner_charm: str +) -> None: """Deploy the charm under test and a TLS provider.""" juju.deploy( charm, @@ -56,16 +58,26 @@ def test_build_and_deploy(charm: str, juju: jubilant.Juju, substrate: Substrate) num_units=NUM_UNITS, trust=True, ) + juju.deploy(glide_runner_charm, app=GLIDE_RUNNER_NAME) tls_config = {"certificate-validity": "6m", "ca-common-name": "valkey"} juju.deploy(TLS_NAME, channel=TLS_CHANNEL, config=tls_config) juju.wait( - lambda status: are_agents_idle(status, APP_NAME, idle_period=30, unit_count=NUM_UNITS), + lambda status: are_agents_idle( + status, + APP_NAME, + GLIDE_RUNNER_NAME, + idle_period=30, + unit_count={ + APP_NAME: NUM_UNITS, + GLIDE_RUNNER_NAME: 1, + }, + ), timeout=600, ) -async def test_certificate_expiration(juju: jubilant.Juju) -> None: +def test_certificate_expiration(juju: jubilant.Juju) -> None: """Test the TLS certificate expiration and renewal on a running cluster.""" _prepare_units_for_ca_expiration_test(juju) @@ -80,9 +92,10 @@ async def test_certificate_expiration(juju: jubilant.Juju) -> None: download_client_certificate_from_unit(juju, APP_NAME) logger.info("Check access with TLS enabled") - addresses = get_cluster_addresses(juju, APP_NAME) - result = await set_key( - hostnames=addresses, + endpoints = get_cluster_endpoints(juju, APP_NAME) + result = set_key( + juju=juju, + endpoints=endpoints, username=CharmUsers.VALKEY_ADMIN.value, password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), tls_enabled=True, @@ -91,13 +104,17 @@ async def test_certificate_expiration(juju: jubilant.Juju) -> None: ) assert result == "OK", "Failed to write data with TLS enabled" - assert await get_key( - hostnames=addresses, - username=CharmUsers.VALKEY_ADMIN.value, - password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), - tls_enabled=True, - key=TEST_KEY, - ) == bytes(TEST_VALUE, "utf-8"), "Failed to read data with TLS enabled" + assert ( + get_key( + juju=juju, + endpoints=endpoints, + username=CharmUsers.VALKEY_ADMIN.value, + password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), + tls_enabled=True, + key=TEST_KEY, + ) + == TEST_VALUE + ), "Failed to read data with TLS enabled" logger.info("Store current certificate before expiration") with open(TLS_CERT_FILE, "r") as file: @@ -108,15 +125,12 @@ async def test_certificate_expiration(juju: jubilant.Juju) -> None: sleep(CERTIFICATE_EXPIRY_TIME) logger.info("Check access with previous certificate fails after expiration") - with pytest.raises(Exception) as exc_info: - await auth_test( - hostnames=addresses, - username=CharmUsers.VALKEY_ADMIN.value, - password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), - tls_enabled=True, - ) - assert "Connection error" in str(exc_info.value), ( - "Access with expired certificate did not fail as expected" + assert not auth_test( + juju=juju, + endpoints=endpoints, + username=CharmUsers.VALKEY_ADMIN.value, + password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), + tls_enabled=True, ) logger.info("Store new certificate after rotation") @@ -130,8 +144,9 @@ async def test_certificate_expiration(juju: jubilant.Juju) -> None: logger.info("Check access with updated certificate") download_client_certificate_from_unit(juju, APP_NAME) - result = await set_key( - hostnames=addresses, + result = set_key( + juju=juju, + endpoints=endpoints, username=CharmUsers.VALKEY_ADMIN.value, password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), tls_enabled=True, @@ -140,13 +155,17 @@ async def test_certificate_expiration(juju: jubilant.Juju) -> None: ) assert result == "OK", "Failed to write data with updated certificate" - assert await get_key( - hostnames=addresses, - username=CharmUsers.VALKEY_ADMIN.value, - password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), - tls_enabled=True, - key=TEST_KEY, - ) == bytes(TEST_VALUE, "utf-8"), "Failed to read data with updated certificate" + assert ( + get_key( + juju=juju, + endpoints=endpoints, + username=CharmUsers.VALKEY_ADMIN.value, + password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), + tls_enabled=True, + key=TEST_KEY, + ) + == TEST_VALUE + ), "Failed to read data with updated certificate" juju.wait( lambda status: does_status_match( @@ -158,7 +177,7 @@ async def test_certificate_expiration(juju: jubilant.Juju) -> None: ) -async def test_ca_rotation_by_config_change(juju: jubilant.Juju) -> None: +def test_ca_rotation_by_config_change(juju: jubilant.Juju) -> None: """Test the CA rotation. The CA certificate should be rotated and the cluster should still be accessible. @@ -194,9 +213,10 @@ async def test_ca_rotation_by_config_change(juju: jubilant.Juju) -> None: assert old_certificate != new_certificate, "Certificate was not updated" logger.info("Check access with updated certificate") - addresses = get_cluster_addresses(juju, APP_NAME) - result = await set_key( - hostnames=addresses, + endpoints = get_cluster_endpoints(juju, APP_NAME) + result = set_key( + juju=juju, + endpoints=endpoints, username=CharmUsers.VALKEY_ADMIN.value, password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), tls_enabled=True, @@ -205,16 +225,20 @@ async def test_ca_rotation_by_config_change(juju: jubilant.Juju) -> None: ) assert result == "OK", "Failed to write data with updated certificate" - assert await get_key( - hostnames=addresses, - username=CharmUsers.VALKEY_ADMIN.value, - password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), - tls_enabled=True, - key=TEST_KEY, - ) == bytes(TEST_VALUE, "utf-8"), "Failed to read data with updated certificate" + assert ( + get_key( + juju=juju, + endpoints=endpoints, + username=CharmUsers.VALKEY_ADMIN.value, + password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), + tls_enabled=True, + key=TEST_KEY, + ) + == TEST_VALUE + ), "Failed to read data with updated certificate" -async def test_ca_rotation_by_expiration(juju: jubilant.Juju) -> None: +def test_ca_rotation_by_expiration(juju: jubilant.Juju) -> None: """Test the CA rotation. The CA certificate should be rotated and the cluster should still be accessible. @@ -245,9 +269,10 @@ async def test_ca_rotation_by_expiration(juju: jubilant.Juju) -> None: assert old_certificate, "Failed to get current certificate" logger.info("Check access with current TLS certificate") - addresses = get_cluster_addresses(juju, APP_NAME) - result = await set_key( - hostnames=addresses, + endpoints = get_cluster_endpoints(juju, APP_NAME) + result = set_key( + juju=juju, + endpoints=endpoints, username=CharmUsers.VALKEY_ADMIN.value, password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), tls_enabled=True, @@ -256,13 +281,17 @@ async def test_ca_rotation_by_expiration(juju: jubilant.Juju) -> None: ) assert result == "OK", "Failed to write data with TLS enabled" - assert await get_key( - hostnames=addresses, - username=CharmUsers.VALKEY_ADMIN.value, - password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), - tls_enabled=True, - key=TEST_KEY, - ) == bytes(TEST_VALUE, "utf-8"), "Failed to read data with TLS enabled" + assert ( + get_key( + juju=juju, + endpoints=endpoints, + username=CharmUsers.VALKEY_ADMIN.value, + password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), + tls_enabled=True, + key=TEST_KEY, + ) + == TEST_VALUE + ), "Failed to read data with TLS enabled" logger.info("Waiting for CA certificate to expire") sleep(CA_EXPIRY_TIME) @@ -272,17 +301,13 @@ async def test_ca_rotation_by_expiration(juju: jubilant.Juju) -> None: ) logger.info("Check access with previous certificate fails after expiration") - with pytest.raises(Exception) as exc_info: - await auth_test( - hostnames=addresses, - username=CharmUsers.VALKEY_ADMIN.value, - password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), - tls_enabled=True, - ) - assert "Connection error" in str(exc_info.value), ( - "Access with expired certificate did not fail as expected" + assert not auth_test( + juju=juju, + endpoints=endpoints, + username=CharmUsers.VALKEY_ADMIN.value, + password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), + tls_enabled=True, ) - logger.info("Store new certificate after rotation") download_client_certificate_from_unit(juju, APP_NAME) with open(TLS_CA_FILE, "r") as ca_file: @@ -295,9 +320,10 @@ async def test_ca_rotation_by_expiration(juju: jubilant.Juju) -> None: assert old_certificate != new_certificate, "Certificate was not updated" logger.info("Check access with updated certificate") - addresses = get_cluster_addresses(juju, APP_NAME) - result = await set_key( - hostnames=addresses, + endpoints = get_cluster_endpoints(juju, APP_NAME) + result = set_key( + juju=juju, + endpoints=endpoints, username=CharmUsers.VALKEY_ADMIN.value, password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), tls_enabled=True, @@ -306,10 +332,14 @@ async def test_ca_rotation_by_expiration(juju: jubilant.Juju) -> None: ) assert result == "OK", "Failed to write data with updated certificate" - assert await get_key( - hostnames=addresses, - username=CharmUsers.VALKEY_ADMIN.value, - password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), - tls_enabled=True, - key=TEST_KEY, - ) == bytes(TEST_VALUE, "utf-8"), "Failed to read data with updated certificate" + assert ( + get_key( + juju=juju, + endpoints=endpoints, + username=CharmUsers.VALKEY_ADMIN.value, + password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), + tls_enabled=True, + key=TEST_KEY, + ) + == TEST_VALUE + ), "Failed to read data with updated certificate" diff --git a/tests/integration/tls/test_private_key.py b/tests/integration/tls/test_private_key.py index 435f202..fcf9ab7 100644 --- a/tests/integration/tls/test_private_key.py +++ b/tests/integration/tls/test_private_key.py @@ -10,6 +10,7 @@ from statuses import TLSStatuses from tests.integration.helpers import ( APP_NAME, + GLIDE_RUNNER_NAME, IMAGE_RESOURCE, TLS_CERT_FILE, TLS_CHANNEL, @@ -18,7 +19,7 @@ are_agents_idle, does_status_match, download_client_certificate_from_unit, - get_cluster_addresses, + get_cluster_endpoints, get_key, get_password, set_key, @@ -31,7 +32,9 @@ TEST_VALUE = "test_value" -def test_build_and_deploy(charm: str, juju: jubilant.Juju, substrate: Substrate) -> None: +def test_build_and_deploy( + charm: str, juju: jubilant.Juju, substrate: Substrate, glide_runner_charm: str +) -> None: """Deploy the charm under test and a TLS provider.""" juju.deploy( charm, @@ -39,10 +42,19 @@ def test_build_and_deploy(charm: str, juju: jubilant.Juju, substrate: Substrate) num_units=NUM_UNITS, trust=True, ) - + juju.deploy(glide_runner_charm, app=GLIDE_RUNNER_NAME) juju.deploy(TLS_NAME, channel=TLS_CHANNEL) juju.wait( - lambda status: are_agents_idle(status, APP_NAME, idle_period=30, unit_count=NUM_UNITS), + lambda status: are_agents_idle( + status, + APP_NAME, + GLIDE_RUNNER_NAME, + idle_period=30, + unit_count={ + APP_NAME: NUM_UNITS, + GLIDE_RUNNER_NAME: 1, + }, + ), timeout=600, ) @@ -69,7 +81,7 @@ def test_invalid_private_key(juju: jubilant.Juju) -> None: ) -async def test_valid_private_key(juju: jubilant.Juju) -> None: +def test_valid_private_key(juju: jubilant.Juju) -> None: logger.info("Updating user secret with valid private key now") private_key = PrivateKey.generate().raw @@ -97,9 +109,10 @@ async def test_valid_private_key(juju: jubilant.Juju) -> None: download_client_certificate_from_unit(juju, APP_NAME) logger.info("Check access with TLS enabled") - addresses = get_cluster_addresses(juju, APP_NAME) - result = await set_key( - hostnames=addresses, + endpoints = get_cluster_endpoints(juju, APP_NAME) + result = set_key( + juju=juju, + endpoints=endpoints, username=CharmUsers.VALKEY_ADMIN.value, password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), tls_enabled=True, @@ -108,13 +121,17 @@ async def test_valid_private_key(juju: jubilant.Juju) -> None: ) assert result == "OK", "Failed to write data with TLS enabled" - assert await get_key( - hostnames=addresses, - username=CharmUsers.VALKEY_ADMIN.value, - password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), - tls_enabled=True, - key=TEST_KEY, - ) == bytes(TEST_VALUE, "utf-8"), "Failed to read data with TLS enabled" + assert ( + get_key( + juju=juju, + endpoints=endpoints, + username=CharmUsers.VALKEY_ADMIN.value, + password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), + tls_enabled=True, + key=TEST_KEY, + ) + == TEST_VALUE + ), "Failed to read data with TLS enabled" logger.info("Store current certificate before expiration") with open(TLS_KEY_FILE, "r") as key_file: @@ -123,7 +140,7 @@ async def test_valid_private_key(juju: jubilant.Juju) -> None: assert private_key_on_unit == private_key, "Expected user-provided private key to be used" -async def test_private_key_updated(juju: jubilant.Juju) -> None: +def test_private_key_updated(juju: jubilant.Juju) -> None: logger.info("Getting current private key and certificate") with open(TLS_KEY_FILE, "r") as key_file: current_private_key = key_file.read() @@ -147,9 +164,10 @@ async def test_private_key_updated(juju: jubilant.Juju) -> None: download_client_certificate_from_unit(juju, APP_NAME) logger.info("Check access with TLS enabled") - addresses = get_cluster_addresses(juju, APP_NAME) - result = await set_key( - hostnames=addresses, + endpoints = get_cluster_endpoints(juju, APP_NAME) + result = set_key( + juju=juju, + endpoints=endpoints, username=CharmUsers.VALKEY_ADMIN.value, password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), tls_enabled=True, @@ -158,13 +176,17 @@ async def test_private_key_updated(juju: jubilant.Juju) -> None: ) assert result == "OK", "Failed to write data with TLS enabled" - assert await get_key( - hostnames=addresses, - username=CharmUsers.VALKEY_ADMIN.value, - password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), - tls_enabled=True, - key=TEST_KEY, - ) == bytes(TEST_VALUE, "utf-8"), "Failed to read data with TLS enabled" + assert ( + get_key( + juju=juju, + endpoints=endpoints, + username=CharmUsers.VALKEY_ADMIN.value, + password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), + tls_enabled=True, + key=TEST_KEY, + ) + == TEST_VALUE + ), "Failed to read data with TLS enabled" logger.info("Getting and comparing updated private key and certificate") with open(TLS_KEY_FILE, "r") as key_file: diff --git a/tests/integration/tls/test_tls.py b/tests/integration/tls/test_tls.py index 2358a5c..e40ef82 100644 --- a/tests/integration/tls/test_tls.py +++ b/tests/integration/tls/test_tls.py @@ -4,11 +4,11 @@ import logging import jubilant -import pytest from literals import CharmUsers, Substrate from tests.integration.helpers import ( APP_NAME, + GLIDE_RUNNER_NAME, IMAGE_RESOURCE, TLS_CHANNEL, TLS_NAME, @@ -16,7 +16,7 @@ are_apps_active_and_agents_idle, auth_test, download_client_certificate_from_unit, - get_cluster_addresses, + get_cluster_endpoints, get_key, get_password, set_key, @@ -29,7 +29,9 @@ TEST_VALUE = "test_value" -def test_build_and_deploy(charm: str, juju: jubilant.Juju, substrate: Substrate) -> None: +def test_build_and_deploy( + charm: str, juju: jubilant.Juju, substrate: Substrate, glide_runner_charm: str +) -> None: """Deploy the charm under test and a TLS provider.""" juju.deploy( charm, @@ -37,23 +39,34 @@ def test_build_and_deploy(charm: str, juju: jubilant.Juju, substrate: Substrate) num_units=NUM_UNITS, trust=True, ) + juju.deploy(glide_runner_charm, app=GLIDE_RUNNER_NAME) juju.deploy(TLS_NAME, channel=TLS_CHANNEL) juju.integrate(f"{APP_NAME}:client-certificates", TLS_NAME) juju.wait( - lambda status: are_agents_idle(status, APP_NAME, idle_period=30, unit_count=NUM_UNITS), + lambda status: are_agents_idle( + status, + APP_NAME, + GLIDE_RUNNER_NAME, + idle_period=30, + unit_count={ + APP_NAME: NUM_UNITS, + GLIDE_RUNNER_NAME: 1, + }, + ), timeout=600, ) -async def test_tls_enabled(juju: jubilant.Juju) -> None: +def test_tls_enabled(juju: jubilant.Juju) -> None: """Check if the TLS has been enabled on app startup.""" logger.info("Downloading TLS certificates from deployed app.") download_client_certificate_from_unit(juju, APP_NAME) - addresses = get_cluster_addresses(juju, APP_NAME) + endpoints = get_cluster_endpoints(juju, APP_NAME) logger.info("Check access with TLS enabled") - result = await set_key( - hostnames=addresses, + result = set_key( + juju=juju, + endpoints=endpoints, username=CharmUsers.VALKEY_ADMIN.value, password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), tls_enabled=True, @@ -62,18 +75,21 @@ async def test_tls_enabled(juju: jubilant.Juju) -> None: ) assert result == "OK", "Failed to write data with TLS enabled" - assert await get_key( - hostnames=addresses, - username=CharmUsers.VALKEY_ADMIN.value, - password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), - tls_enabled=True, - key=TEST_KEY, - ) == bytes(TEST_VALUE, "utf-8"), "Failed to read data with TLS enabled" + assert ( + get_key( + juju=juju, + endpoints=endpoints, + username=CharmUsers.VALKEY_ADMIN.value, + password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), + tls_enabled=True, + key=TEST_KEY, + ) + == TEST_VALUE + ), "Failed to read data with TLS enabled" logger.info("Check access without certs fails when TLS enabled") - with pytest.raises(Exception) as exc_info: - await auth_test(addresses, username=None, password=None) - assert "Connection error" in str(exc_info.value), "Access without TLS did not fail as expected" + + assert not auth_test(juju, endpoints, username=None, password=None) def test_scale_up_with_tls_enabled(juju: jubilant.Juju) -> None: @@ -88,7 +104,7 @@ def test_scale_up_with_tls_enabled(juju: jubilant.Juju) -> None: ) -async def test_disable_tls(juju: jubilant.Juju) -> None: +def test_disable_tls(juju: jubilant.Juju) -> None: """Disable TLS on a running cluster and check if it is still accessible.""" logger.info("Removing client-certificates relation") juju.remove_relation(f"{APP_NAME}:client-certificates", f"{TLS_NAME}:certificates") @@ -98,10 +114,11 @@ async def test_disable_tls(juju: jubilant.Juju) -> None: timeout=600, ) - addresses = get_cluster_addresses(juju, APP_NAME) + endpoints = get_cluster_endpoints(juju, APP_NAME) logger.info("Check access with TLS disabled") - result = await set_key( - hostnames=addresses, + result = set_key( + juju=juju, + endpoints=endpoints, username=CharmUsers.VALKEY_ADMIN.value, password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), tls_enabled=False, @@ -110,16 +127,20 @@ async def test_disable_tls(juju: jubilant.Juju) -> None: ) assert result == "OK", "Failed to write data after TLS was disabled" - assert await get_key( - hostnames=addresses, - username=CharmUsers.VALKEY_ADMIN.value, - password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), - tls_enabled=False, - key=TEST_KEY, - ) == bytes(TEST_VALUE, "utf-8"), "Failed to read data after TLS was disabled" + assert ( + get_key( + juju=juju, + endpoints=endpoints, + username=CharmUsers.VALKEY_ADMIN.value, + password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), + tls_enabled=False, + key=TEST_KEY, + ) + == TEST_VALUE + ), "Failed to read data after TLS was disabled" -async def test_enable_tls(juju: jubilant.Juju) -> None: +def test_enable_tls(juju: jubilant.Juju) -> None: """Enable TLS on a running cluster and check if it is still accessible.""" logger.info("Enabling client TLS") juju.integrate(f"{APP_NAME}:client-certificates", TLS_NAME) @@ -131,10 +152,11 @@ async def test_enable_tls(juju: jubilant.Juju) -> None: logger.info("Downloading TLS certificates from deployed app.") download_client_certificate_from_unit(juju, APP_NAME) - addresses = get_cluster_addresses(juju, APP_NAME) + endpoints = get_cluster_endpoints(juju, APP_NAME) logger.info("Check access with TLS enabled") - result = await set_key( - hostnames=addresses, + result = set_key( + juju=juju, + endpoints=endpoints, username=CharmUsers.VALKEY_ADMIN.value, password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), tls_enabled=True, @@ -143,15 +165,17 @@ async def test_enable_tls(juju: jubilant.Juju) -> None: ) assert result == "OK", "Failed to write data with TLS enabled" - assert await get_key( - hostnames=addresses, - username=CharmUsers.VALKEY_ADMIN.value, - password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), - tls_enabled=True, - key=TEST_KEY, - ) == bytes(TEST_VALUE, "utf-8"), "Failed to read data with TLS enabled" + assert ( + get_key( + juju=juju, + endpoints=endpoints, + username=CharmUsers.VALKEY_ADMIN.value, + password=get_password(juju, user=CharmUsers.VALKEY_ADMIN), + tls_enabled=True, + key=TEST_KEY, + ) + == TEST_VALUE + ), "Failed to read data with TLS enabled" logger.info("Check access without certs fails when TLS enabled") - with pytest.raises(Exception) as exc_info: - await auth_test(addresses, username=None, password=None) - assert "Connection error" in str(exc_info.value), "Access without TLS did not fail as expected" + assert not auth_test(juju, endpoints, username=None, password=None)