Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions src/tcp_health_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import socket
import sys
import time
from logging.handlers import TimedRotatingFileHandler

PROTOCOL_VERSION = "1.0"

Expand Down Expand Up @@ -117,10 +118,10 @@ def tcp_check(
port = int(port)

try:
with socket.create_connection((host, port), timeout=5):
logging.info(f"TCP check successful for {server}")
with socket.create_connection((host, port), timeout=2):
logging.info(f"Network is reachable on {server}")
except (socket.timeout, socket.error) as e:
logging.warning(f"TCP check failed for {server}: {e}")
logging.warning(f"Network is not reachable on {server}: {e}")
nic_down = True

# Load current failure count from state file
Expand Down Expand Up @@ -148,22 +149,25 @@ def tcp_check(
else:
# Reset failure count on success
if failure_count > 0:
logging.info("✅ All servers reachable. Resetting failure count.")
logging.info("Network is reachable. Resetting failure count.")
write_failure_count(state_file, 0)
else:
logging.info("✅ All servers reachable. No alert triggered.")


def _configure_logging():
"""Configure logging so output is visible when Consul runs this script as a subprocess."""
fmt = "%(levelname)s: %(message)s"
logging.basicConfig(level=logging.DEBUG, format=fmt, stream=sys.stderr, force=True)
fmt = "%(asctime)s %(levelname)s %(message)s"
datefmt = "%Y-%m-%d %H:%M:%S"
formatter = logging.Formatter(fmt, datefmt=datefmt)
logging.basicConfig(
level=logging.DEBUG, format=fmt, datefmt=datefmt, stream=sys.stderr, force=True
)
snap_data = os.environ.get("SNAP_DATA")
if snap_data:
handler = logging.FileHandler(
os.path.join(snap_data, "tcp_health_check.log"), encoding="utf-8"
log_path = os.path.join(snap_data, "tcp_health_check.log")
handler = TimedRotatingFileHandler(
log_path, when="midnight", backupCount=3, encoding="utf-8"
)
handler.setFormatter(logging.Formatter(fmt))
handler.setFormatter(formatter)
logging.getLogger().addHandler(handler)


Expand Down
32 changes: 16 additions & 16 deletions tests/unit/test_tcp_health_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,12 @@ def test_tcp_check_all_servers_reachable(self, mock_create_connection, caplog):
if call[0] # Only calls with positional arguments
]
assert len(connection_calls) == 3
assert connection_calls[0] == call(("10.0.0.1", 8301), timeout=5)
assert connection_calls[1] == call(("10.0.0.2", 8301), timeout=5)
assert connection_calls[2] == call(("10.0.0.3", 8301), timeout=5)
assert connection_calls[0] == call(("10.0.0.1", 8301), timeout=2)
assert connection_calls[1] == call(("10.0.0.2", 8301), timeout=2)
assert connection_calls[2] == call(("10.0.0.3", 8301), timeout=2)

# Verify success message
assert "✅ All servers reachable. No alert triggered." in caplog.text
assert "Network is reachable on" in caplog.text

@patch("tcp_health_check.send_nic_down_alert")
@patch("tcp_health_check.socket.create_connection")
Expand All @@ -163,7 +163,7 @@ def side_effect(*args, **kwargs):

# Verify alert was sent
mock_send_alert.assert_called_once_with("data/socket.sock")
assert "TCP check failed for 10.0.0.2:8301" in caplog.text
assert "Network is not reachable on 10.0.0.2:8301" in caplog.text

@patch("tcp_health_check.send_nic_down_alert")
@patch("tcp_health_check.socket.create_connection")
Expand All @@ -179,8 +179,8 @@ def test_tcp_check_all_servers_unreachable(

# Verify alert was sent
mock_send_alert.assert_called_once_with("data/socket.sock")
assert "TCP check failed for 10.0.0.1:8301" in caplog.text
assert "TCP check failed for 10.0.0.2:8301" in caplog.text
assert "Network is not reachable on 10.0.0.1:8301" in caplog.text
assert "Network is not reachable on 10.0.0.2:8301" in caplog.text

@patch("tcp_health_check.send_nic_down_alert")
@patch("tcp_health_check.socket.create_connection")
Expand All @@ -194,7 +194,7 @@ def test_tcp_check_socket_timeout(self, mock_create_connection, mock_send_alert,

# Verify alert was sent
mock_send_alert.assert_called_once_with("data/socket.sock")
assert "TCP check failed for 10.0.0.1:8301" in caplog.text
assert "Network is not reachable on 10.0.0.1:8301" in caplog.text

@patch("tcp_health_check.socket.create_connection")
def test_tcp_check_no_socket_path_on_failure(self, mock_create_connection, caplog):
Expand All @@ -218,8 +218,8 @@ def test_tcp_check_single_server(self, mock_create_connection, caplog):
tcp_check(servers, "data/socket.sock")

# Verify server was checked
mock_create_connection.assert_called_once_with(("192.168.1.10", 9301), timeout=5)
assert "TCP check successful for 192.168.1.10:9301" in caplog.text
mock_create_connection.assert_called_once_with(("192.168.1.10", 9301), timeout=2)
assert "Network is reachable on 192.168.1.10:9301" in caplog.text

@patch("tcp_health_check.send_nic_down_alert")
@patch("tcp_health_check.socket.create_connection")
Expand All @@ -246,9 +246,9 @@ def side_effect(*args, **kwargs):

# Verify alert was sent (because at least one failed)
mock_send_alert.assert_called_once_with("data/socket.sock")
assert "TCP check successful for 10.0.0.1:8301" in caplog.text
assert "TCP check failed for 10.0.0.2:8301" in caplog.text
assert "TCP check successful for 10.0.0.3:8301" in caplog.text
assert "Network is reachable on 10.0.0.1:8301" in caplog.text
assert "Network is not reachable on 10.0.0.2:8301" in caplog.text
assert "Network is reachable on 10.0.0.3:8301" in caplog.text

@patch("tcp_health_check.socket.create_connection")
def test_tcp_check_different_ports(self, mock_create_connection, caplog):
Expand All @@ -266,9 +266,9 @@ def test_tcp_check_different_ports(self, mock_create_connection, caplog):
if call[0] # Only calls with positional arguments
]
assert len(connection_calls) == 3
assert connection_calls[0] == call(("10.0.0.1", 8301), timeout=5)
assert connection_calls[1] == call(("10.0.0.2", 9301), timeout=5)
assert connection_calls[2] == call(("10.0.0.3", 7301), timeout=5)
assert connection_calls[0] == call(("10.0.0.1", 8301), timeout=2)
assert connection_calls[1] == call(("10.0.0.2", 9301), timeout=2)
assert connection_calls[2] == call(("10.0.0.3", 7301), timeout=2)

@patch("tcp_health_check.socket.create_connection")
def test_tcp_check_ipv6_address(self, mock_create_connection, caplog):
Expand Down
Loading