diff --git a/kubernetes/src/charm.py b/kubernetes/src/charm.py index c8bb4a94a..4341e3c0a 100755 --- a/kubernetes/src/charm.py +++ b/kubernetes/src/charm.py @@ -15,7 +15,6 @@ import logging import random -from socket import getfqdn from time import sleep import ops @@ -101,7 +100,7 @@ from relations.mysql_root import MySQLRootRelation from rotate_mysql_logs import RotateMySQLLogs, RotateMySQLLogsCharmEvents from upgrade import MySQLK8sUpgrade, get_mysql_k8s_dependencies_model -from utils import compare_dictionaries, dotappend, generate_random_password +from utils import compare_dictionaries, dotappend, generate_random_password, get_k8s_fqdn logger = logging.getLogger(__name__) @@ -334,7 +333,7 @@ def get_unit_address(self, unit: Unit, relation_name: str = PEER) -> str: Translate juju unit name to resolvable hostname. """ unit_hostname = self.get_unit_hostname(unit.name) - unit_dns_domain = getfqdn(self.get_unit_hostname(unit.name)) + unit_dns_domain = get_k8s_fqdn(self.get_unit_hostname(unit.name)) # When fully propagated, DNS domain name should contain unit hostname. # For example: @@ -1072,7 +1071,7 @@ def _on_database_storage_detaching(self, _) -> None: logger.info("Switching primary to unit 0") try: self._mysql.set_cluster_primary( - new_primary_address=getfqdn(self.get_unit_hostname(f"{self.app.name}/0")) + new_primary_address=get_k8s_fqdn(self.get_unit_hostname(f"{self.app.name}/0")) ) except MySQLSetClusterPrimaryError: logger.warning("Failed to switch primary to unit 0") diff --git a/kubernetes/src/relations/mysql_provider.py b/kubernetes/src/relations/mysql_provider.py index 5d56cf640..57f8c9ded 100644 --- a/kubernetes/src/relations/mysql_provider.py +++ b/kubernetes/src/relations/mysql_provider.py @@ -4,7 +4,6 @@ """Library containing the implementation of the standard relation.""" import logging -import socket import typing from charms.data_platform_libs.v0.data_interfaces import DatabaseProvides, DatabaseRequestedEvent @@ -24,7 +23,7 @@ from constants import CONTAINER_NAME, CONTAINER_RESTARTS, DB_RELATION_NAME, PASSWORD_LENGTH from k8s_helpers import KubernetesClientError -from utils import dotappend, generate_random_password +from utils import dotappend, generate_random_password, get_k8s_fqdn logger = logging.getLogger(__name__) @@ -127,8 +126,8 @@ def _on_database_requested(self, event: DatabaseRequestedEvent) -> None: # create k8s services for endpoints self.charm.k8s_helpers.create_endpoint_services(["primary", "replicas"]) - primary_endpoint = dotappend(socket.getfqdn(f"{self.charm.app.name}-primary")) - replicas_endpoint = dotappend(socket.getfqdn(f"{self.charm.app.name}-replicas")) + primary_endpoint = dotappend(get_k8s_fqdn(f"{self.charm.app.name}-primary")) + replicas_endpoint = dotappend(get_k8s_fqdn(f"{self.charm.app.name}-replicas")) db_version = self.charm._mysql.get_mysql_version() diff --git a/kubernetes/src/upgrade.py b/kubernetes/src/upgrade.py index 8e442d59b..d49b978f9 100644 --- a/kubernetes/src/upgrade.py +++ b/kubernetes/src/upgrade.py @@ -5,7 +5,6 @@ import json import logging -from socket import getfqdn from typing import TYPE_CHECKING from charms.data_platform_libs.v0.upgrade import ( @@ -33,6 +32,7 @@ import k8s_helpers from constants import CONTAINER_NAME, MYSQLD_SERVICE +from utils import get_k8s_fqdn if TYPE_CHECKING: from charm import MySQLOperatorCharm @@ -163,7 +163,7 @@ def _pre_upgrade_prepare(self) -> None: """ if self.charm._mysql.get_primary_label() != f"{self.charm.app.name}-0": # set the primary to the first unit for switchover mitigation - new_primary = getfqdn(self.charm.get_unit_hostname(f"{self.charm.app.name}/0")) + new_primary = get_k8s_fqdn(self.charm.get_unit_hostname(f"{self.charm.app.name}/0")) self.charm._mysql.set_cluster_primary(new_primary) # set slow shutdown on all instances diff --git a/kubernetes/src/utils.py b/kubernetes/src/utils.py index 47ebe1fa5..8e4a984b0 100644 --- a/kubernetes/src/utils.py +++ b/kubernetes/src/utils.py @@ -5,6 +5,7 @@ import re import secrets +import socket import string @@ -82,3 +83,23 @@ def dotappend(string: str) -> str: if not string.endswith("."): string += "." return string + + +def get_k8s_fqdn(name: str) -> str: + """Resolve the canonical FQDN for a Kubernetes service or pod name.""" + try: + info = socket.getaddrinfo( + name, + None, + family=socket.AF_UNSPEC, + flags=socket.AI_CANONNAME, + type=socket.SOCK_STREAM, + ) + except socket.gaierror as e: + raise RuntimeError(f"Failed to resolve canonical name for {name}") from e + + for entry in info: + if canonname := entry[3]: + return canonname + + raise RuntimeError(f"Could not determine canonical name for {name}") diff --git a/kubernetes/tests/unit/test_charm.py b/kubernetes/tests/unit/test_charm.py index e473d4bbd..b2ba5c5ac 100644 --- a/kubernetes/tests/unit/test_charm.py +++ b/kubernetes/tests/unit/test_charm.py @@ -313,6 +313,17 @@ def test_on_config_changed(self): self.charm.peers.data[self.charm.app]["cluster-name"], "not_valid_cluster_name" ) + @patch( + "charm.get_k8s_fqdn", + return_value="mysql-k8s-0.mysql-k8s-endpoints.default.svc.cluster.local", + ) + def test_get_unit_address(self, mock_get_k8s_fqdn): + self.assertEqual( + self.charm.get_unit_address(self.charm.unit), + "mysql-k8s-0.mysql-k8s-endpoints.default.svc.cluster.local.", + ) + mock_get_k8s_fqdn.assert_called_once_with("mysql-k8s-0.mysql-k8s-endpoints") + @patch("charm.MySQLOperatorCharm.get_unit_address", return_value="mysql-k8s.somedomain") @patch("mysql_k8s_helpers.MySQL.is_data_dir_initialised", return_value=False) def test_mysql_property(self, _, mock_get_unit_address): diff --git a/kubernetes/tests/unit/test_database.py b/kubernetes/tests/unit/test_database.py index cb513b5b9..2641fc1bc 100644 --- a/kubernetes/tests/unit/test_database.py +++ b/kubernetes/tests/unit/test_database.py @@ -67,8 +67,10 @@ def tearDown(self) -> None: @patch( "relations.mysql_provider.generate_random_password", return_value="super_secure_password" ) + @patch("relations.mysql_provider.get_k8s_fqdn") def test_database_requested( self, + mock_get_k8s_fqdn, _generate_random_password, _create_scoped_user, _create_database, @@ -80,6 +82,8 @@ def test_database_requested( _cluster_metadata_exists, _get_unit_address, ): + mock_get_k8s_fqdn.side_effect = ["mysql-k8s-primary", "mysql-k8s-replicas"] + # run start-up events to enable usage of the helper class self.harness.set_leader(True) self.harness.container_pebble_ready("mysql") @@ -124,3 +128,4 @@ def test_database_requested( _create_endpoint_services.assert_called_once() _update_endpoints.assert_called() _wait_service_ready.assert_called_once() + self.assertEqual(mock_get_k8s_fqdn.call_count, 2) diff --git a/kubernetes/tests/unit/test_upgrade.py b/kubernetes/tests/unit/test_upgrade.py index 305ff2ff1..664dfc07d 100644 --- a/kubernetes/tests/unit/test_upgrade.py +++ b/kubernetes/tests/unit/test_upgrade.py @@ -144,10 +144,14 @@ def test_log_rollback(self, mock_logging): @patch("mysql_k8s_helpers.MySQL.set_dynamic_variable") @patch("mysql_k8s_helpers.MySQL.get_primary_label", return_value="mysql-k8s-1") @patch("mysql_k8s_helpers.MySQL.set_cluster_primary") + @patch( + "upgrade.get_k8s_fqdn", return_value="mysql-k8s-0.mysql-k8s-endpoints.svc.cluster.local" + ) @patch("k8s_helpers.KubernetesHelpers.set_rolling_update_partition") def test_pre_upgrade_prepare( self, mock_set_rolling_update_partition, + mock_get_k8s_fqdn, mock_set_cluster_primary, mock_get_primary_label, mock_set_dynamic_variable, @@ -159,7 +163,10 @@ def test_pre_upgrade_prepare( self.charm.upgrade._pre_upgrade_prepare() - mock_set_cluster_primary.assert_called_once() + mock_set_cluster_primary.assert_called_once_with( + "mysql-k8s-0.mysql-k8s-endpoints.svc.cluster.local" + ) + mock_get_k8s_fqdn.assert_called_once_with("mysql-k8s-0.mysql-k8s-endpoints") mock_get_primary_label.assert_called_once() mock_set_rolling_update_partition.assert_called_once() assert mock_set_dynamic_variable.call_count == 2 diff --git a/kubernetes/tests/unit/test_utils.py b/kubernetes/tests/unit/test_utils.py index 03f969248..ec68fb283 100644 --- a/kubernetes/tests/unit/test_utils.py +++ b/kubernetes/tests/unit/test_utils.py @@ -1,9 +1,11 @@ # Copyright 2023 Canonical Ltd. # See LICENSE file for licensing details. +import socket import unittest +from unittest.mock import patch -from utils import any_memory_to_bytes, generate_random_password, split_mem +from utils import any_memory_to_bytes, generate_random_password, get_k8s_fqdn, split_mem class TestUtils(unittest.TestCase): @@ -21,3 +23,81 @@ def test_any_memory_to_bytes(self): self.assertEqual(any_memory_to_bytes("1Gi"), 1073741824) self.assertEqual(any_memory_to_bytes("1G"), 10**9) self.assertEqual(any_memory_to_bytes("1024"), 1024) + + @patch("utils.socket.getaddrinfo") + def test_get_k8s_fqdn(self, mock_getaddrinfo): + mock_getaddrinfo.return_value = [ + ( + None, + None, + None, + "", + None, + ), + ( + None, + None, + None, + "mysql-2.mysql-endpoints.default.svc.cluster.local.", + None, + ), + ] + + self.assertEqual( + get_k8s_fqdn("mysql-2.mysql-endpoints"), + "mysql-2.mysql-endpoints.default.svc.cluster.local.", + ) + mock_getaddrinfo.assert_called_once_with( + "mysql-2.mysql-endpoints", + None, + family=socket.AF_UNSPEC, + flags=socket.AI_CANONNAME, + type=socket.SOCK_STREAM, + ) + + @patch("utils.socket.getaddrinfo", side_effect=socket.gaierror) + def test_get_k8s_fqdn_resolution_error(self, mock_getaddrinfo): + with self.assertRaisesRegex( + RuntimeError, "Failed to resolve canonical name for mysql-2.mysql-endpoints" + ): + get_k8s_fqdn("mysql-2.mysql-endpoints") + + mock_getaddrinfo.assert_called_once_with( + "mysql-2.mysql-endpoints", + None, + family=socket.AF_UNSPEC, + flags=socket.AI_CANONNAME, + type=socket.SOCK_STREAM, + ) + + @patch("utils.socket.getaddrinfo") + def test_get_k8s_fqdn_without_canonical_name(self, mock_getaddrinfo): + mock_getaddrinfo.return_value = [ + ( + None, + None, + None, + "", + None, + ), + ( + None, + None, + None, + "", + None, + ), + ] + + with self.assertRaisesRegex( + RuntimeError, "Could not determine canonical name for mysql-2.mysql-endpoints" + ): + get_k8s_fqdn("mysql-2.mysql-endpoints") + + mock_getaddrinfo.assert_called_once_with( + "mysql-2.mysql-endpoints", + None, + family=socket.AF_UNSPEC, + flags=socket.AI_CANONNAME, + type=socket.SOCK_STREAM, + )