diff --git a/core/testcontainers/core/wait_strategies.py b/core/testcontainers/core/wait_strategies.py index a9627548..a1f5b112 100644 --- a/core/testcontainers/core/wait_strategies.py +++ b/core/testcontainers/core/wait_strategies.py @@ -27,17 +27,18 @@ """ import re +import socket import time from datetime import timedelta -from typing import TYPE_CHECKING, Union +from pathlib import Path +from typing import Any, Callable, Optional, Union +from urllib.error import HTTPError, URLError +from urllib.request import Request, urlopen from testcontainers.core.utils import setup_logger # Import base classes from waiting_utils to make them available for tests -from .waiting_utils import WaitStrategy - -if TYPE_CHECKING: - from .waiting_utils import WaitStrategyTarget +from .waiting_utils import WaitStrategy, WaitStrategyTarget logger = setup_logger(__name__) @@ -155,3 +156,671 @@ def wait_until_ready(self, container: "WaitStrategyTarget") -> None: ) time.sleep(self._poll_interval) + + +class HttpWaitStrategy(WaitStrategy): + """ + Wait for an HTTP endpoint to be available and return expected status code(s). + + This strategy makes HTTP requests to a specified endpoint and waits for it to + return an acceptable status code. It supports various HTTP methods, headers, + authentication, and custom response validation. + + Args: + port: The port number to connect to + path: The HTTP path to request (default: "/") + + Example: + # Basic HTTP check + strategy = HttpWaitStrategy(8080).for_status_code(200) + + # HTTPS with custom path + strategy = HttpWaitStrategy(443, "/health").using_tls().for_status_code(200) + + # Custom validation + strategy = HttpWaitStrategy(8080).for_response_predicate(lambda body: "ready" in body) + + # Create from URL + strategy = HttpWaitStrategy.from_url("https://localhost:8080/api/health") + """ + + def __init__(self, port: int, path: Optional[str] = "/") -> None: + super().__init__() + self._port = port + self._path = "/" if path is None else (path if path.startswith("/") else f"/{path}") + self._status_codes: set[int] = {200} + self._status_code_predicate: Optional[Callable[[int], bool]] = None + self._tls = False + self._headers: dict[str, str] = {} + self._basic_auth: Optional[tuple[str, str]] = None + self._response_predicate: Optional[Callable[[str], bool]] = None + self._method = "GET" + self._body: Optional[str] = None + self._insecure_tls = False + + def with_startup_timeout(self, timeout: Union[int, timedelta]) -> "HttpWaitStrategy": + """Set the maximum time to wait for the container to be ready.""" + if isinstance(timeout, timedelta): + self._startup_timeout = int(timeout.total_seconds()) + else: + self._startup_timeout = timeout + return self + + def with_poll_interval(self, interval: Union[float, timedelta]) -> "HttpWaitStrategy": + """Set how frequently to check if the container is ready.""" + if isinstance(interval, timedelta): + self._poll_interval = interval.total_seconds() + else: + self._poll_interval = interval + return self + + @classmethod + def from_url(cls, url: str) -> "HttpWaitStrategy": + """ + Create an HttpWaitStrategy from a URL string. + + Args: + url: The URL to wait for (e.g., "http://localhost:8080/api/health") + + Returns: + An HttpWaitStrategy configured for the given URL + + Example: + strategy = HttpWaitStrategy.from_url("https://localhost:8080/api/health") + """ + from urllib.parse import urlparse + + parsed = urlparse(url) + port = parsed.port or (443 if parsed.scheme == "https" else 80) + path = parsed.path or "/" + + strategy = cls(port, path) + + if parsed.scheme == "https": + strategy.using_tls() + + return strategy + + def for_status_code(self, code: int) -> "HttpWaitStrategy": + """ + Add an acceptable status code. + + Args: + code: HTTP status code to accept + + Returns: + self for method chaining + """ + self._status_codes.add(code) + return self + + def for_status_code_matching(self, predicate: Callable[[int], bool]) -> "HttpWaitStrategy": + """ + Set a predicate to match status codes against. + + Args: + predicate: Function that takes a status code and returns True if acceptable + + Returns: + self for method chaining + """ + self._status_code_predicate = predicate + return self + + def for_response_predicate(self, predicate: Callable[[str], bool]) -> "HttpWaitStrategy": + """ + Set a predicate to match response body against. + + Args: + predicate: Function that takes response body and returns True if acceptable + + Returns: + self for method chaining + """ + self._response_predicate = predicate + return self + + def using_tls(self, insecure: bool = False) -> "HttpWaitStrategy": + """ + Use HTTPS instead of HTTP. + + Args: + insecure: If True, skip SSL certificate verification + + Returns: + self for method chaining + """ + self._tls = True + self._insecure_tls = insecure + return self + + def with_header(self, name: str, value: str) -> "HttpWaitStrategy": + """ + Add a header to the request. + + Args: + name: Header name + value: Header value + + Returns: + self for method chaining + """ + self._headers[name] = value + return self + + def with_basic_credentials(self, username: str, password: str) -> "HttpWaitStrategy": + """ + Add basic auth credentials. + + Args: + username: Basic auth username + password: Basic auth password + + Returns: + self for method chaining + """ + self._basic_auth = (username, password) + return self + + def with_method(self, method: str) -> "HttpWaitStrategy": + """ + Set the HTTP method to use. + + Args: + method: HTTP method (GET, POST, PUT, etc.) + + Returns: + self for method chaining + """ + self._method = method.upper() + return self + + def with_body(self, body: str) -> "HttpWaitStrategy": + """ + Set the request body. + + Args: + body: Request body as string + + Returns: + self for method chaining + """ + self._body = body + return self + + def _setup_headers(self) -> dict[str, str]: + """Set up headers for the HTTP request.""" + import base64 + + headers = self._headers.copy() + if self._basic_auth: + auth = base64.b64encode(f"{self._basic_auth[0]}:{self._basic_auth[1]}".encode()).decode() + headers["Authorization"] = f"Basic {auth}" + return headers + + def _setup_ssl_context(self) -> Optional[Any]: + """Set up SSL context if needed.""" + import ssl + + if self._tls and self._insecure_tls: + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + return ssl_context + return None + + def _build_url(self, container: WaitStrategyTarget) -> str: + """Build the URL for the HTTP request.""" + protocol = "https" if self._tls else "http" + host = container.get_container_host_ip() + port = int(container.get_exposed_port(self._port)) + return f"{protocol}://{host}:{port}{self._path}" + + def _check_response(self, response: Any, url: str) -> bool: + """Check if the response is acceptable.""" + status_code = response.status + + # Check status code matches + if status_code in self._status_codes or ( + self._status_code_predicate and self._status_code_predicate(status_code) + ): + # Check response body if needed + if self._response_predicate is not None: + body = response.read().decode() + return self._response_predicate(body) + return True + else: + raise HTTPError(url, status_code, "Unexpected status code", response.headers, None) + + def wait_until_ready(self, container: WaitStrategyTarget) -> None: + """ + Wait until the HTTP endpoint is ready and returns an acceptable response. + + Args: + container: The container to monitor + + Raises: + TimeoutError: If the endpoint doesn't become ready within the timeout period + """ + start_time = time.time() + headers = self._setup_headers() + ssl_context = self._setup_ssl_context() + url = self._build_url(container) + + while True: + if time.time() - start_time > self._startup_timeout: + self._raise_timeout_error(url) + + if self._try_http_request(url, headers, ssl_context): + return + + time.sleep(self._poll_interval) + + def _raise_timeout_error(self, url: str) -> None: + """Raise a timeout error with detailed information.""" + raise TimeoutError( + f"HTTP endpoint not ready within {self._startup_timeout} seconds. " + f"Endpoint: {url}. " + f"Method: {self._method}. " + f"Expected status codes: {self._status_codes}. " + f"Hint: Check if the service is listening on port {self._port}, " + f"the endpoint path is correct, and the service is configured to respond to {self._method} requests." + ) + + def _try_http_request(self, url: str, headers: dict[str, str], ssl_context: Any) -> bool: + """Try to make an HTTP request and return True if successful.""" + try: + request = Request( + url, + headers=headers, + method=self._method, + data=self._body.encode() if self._body else None, + ) + + with urlopen(request, timeout=1, context=ssl_context) as response: + return self._check_response(response, url) + + except (URLError, HTTPError) as e: + return self._handle_http_error(e) + except (ConnectionResetError, ConnectionRefusedError, BrokenPipeError, OSError) as e: + # Handle connection-level errors that can occur during HTTP requests + logger.debug(f"HTTP connection failed: {e!s}") + return False + + def _handle_http_error(self, error: Union[URLError, HTTPError]) -> bool: + """Handle HTTP errors and return True if error is acceptable.""" + if isinstance(error, HTTPError) and ( + error.code in self._status_codes + or (self._status_code_predicate and self._status_code_predicate(error.code)) + ): + return True + logger.debug(f"HTTP request failed: {error!s}") + return False + + +class HealthcheckWaitStrategy(WaitStrategy): + """ + Wait for the container's health check to report as healthy. + + This strategy monitors the container's Docker health check status and waits + for it to report as "healthy". It requires the container to have a health + check configured in its Dockerfile or container configuration. + + Example: + # Wait for container to be healthy + strategy = HealthcheckWaitStrategy() + + # With custom timeout + strategy = HealthcheckWaitStrategy().with_startup_timeout(60) + + Note: + The container must have a HEALTHCHECK instruction in its Dockerfile + or health check configured during container creation for this strategy + to work. If no health check is configured, this strategy will raise + a RuntimeError. + """ + + def __init__(self) -> None: + super().__init__() + + def with_startup_timeout(self, timeout: Union[int, timedelta]) -> "HealthcheckWaitStrategy": + """Set the maximum time to wait for the container to be ready.""" + if isinstance(timeout, timedelta): + self._startup_timeout = int(timeout.total_seconds()) + else: + self._startup_timeout = timeout + return self + + def with_poll_interval(self, interval: Union[float, timedelta]) -> "HealthcheckWaitStrategy": + """Set how frequently to check if the container is ready.""" + if isinstance(interval, timedelta): + self._poll_interval = interval.total_seconds() + else: + self._poll_interval = interval + return self + + def wait_until_ready(self, container: WaitStrategyTarget) -> None: + """ + Wait until the container's health check reports as healthy. + + Args: + container: The container to monitor + + Raises: + TimeoutError: If the health check doesn't report healthy within the timeout period + RuntimeError: If no health check is configured or if the health check reports unhealthy + """ + start_time = time.time() + wrapped = container.get_wrapped_container() + + while True: + if time.time() - start_time > self._startup_timeout: + wrapped.reload() # Refresh container state + health = wrapped.attrs.get("State", {}).get("Health", {}) + status = health.get("Status") if health else "no health check" + raise TimeoutError( + f"Container health check did not report healthy within {self._startup_timeout} seconds. " + f"Current status: {status}. " + f"Hint: Check if the health check command is working correctly, " + f"the application is starting properly, and the health check interval is appropriate." + ) + + wrapped.reload() # Refresh container state + health = wrapped.attrs.get("State", {}).get("Health", {}) + + # No health check configured + if not health: + raise RuntimeError( + "No health check configured for container. " + "Add HEALTHCHECK instruction to Dockerfile or configure health check in container creation. " + "Example: HEALTHCHECK CMD curl -f http://localhost:8080/health || exit 1" + ) + + status = health.get("Status") + + if status == "healthy": + return + elif status == "unhealthy": + # Get the last health check log for better error reporting + log = health.get("Log", []) + last_log = log[-1] if log else {} + exit_code = last_log.get("ExitCode", "unknown") + output = last_log.get("Output", "no output") + + raise RuntimeError( + f"Container health check reported unhealthy. " + f"Exit code: {exit_code}, " + f"Output: {output}. " + f"Hint: Check the health check command, ensure the application is responding correctly, " + f"and verify the health check endpoint or command is working as expected." + ) + + time.sleep(self._poll_interval) + + +class PortWaitStrategy(WaitStrategy): + """ + Wait for a port to be available on the container. + + This strategy attempts to establish a TCP connection to a specified port + on the container and waits until the connection succeeds. It's useful for + waiting for services that need to be listening on a specific port. + + Args: + port: The port number to check for availability + + Example: + # Wait for port 8080 to be available + strategy = PortWaitStrategy(8080) + + # Wait for database port with custom timeout + strategy = PortWaitStrategy(5432).with_startup_timeout(30) + """ + + def __init__(self, port: int) -> None: + super().__init__() + self._port = port + + def with_startup_timeout(self, timeout: Union[int, timedelta]) -> "PortWaitStrategy": + """Set the maximum time to wait for the container to be ready.""" + if isinstance(timeout, timedelta): + self._startup_timeout = int(timeout.total_seconds()) + else: + self._startup_timeout = timeout + return self + + def with_poll_interval(self, interval: Union[float, timedelta]) -> "PortWaitStrategy": + """Set how frequently to check if the container is ready.""" + if isinstance(interval, timedelta): + self._poll_interval = interval.total_seconds() + else: + self._poll_interval = interval + return self + + def wait_until_ready(self, container: WaitStrategyTarget) -> None: + """ + Wait until the specified port is available for connection. + + Args: + container: The container to monitor + + Raises: + TimeoutError: If the port doesn't become available within the timeout period + """ + start_time = time.time() + host = container.get_container_host_ip() + port = int(container.get_exposed_port(self._port)) + + while True: + if time.time() - start_time > self._startup_timeout: + raise TimeoutError( + f"Port {self._port} not available within {self._startup_timeout} seconds. " + f"Attempted connection to {host}:{port}. " + f"Hint: Check if the service is configured to listen on port {self._port}, " + f"the service is starting correctly, and there are no firewall or network issues." + ) + + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.settimeout(1) + s.connect((host, port)) + return + except (socket.timeout, ConnectionRefusedError, OSError): + time.sleep(self._poll_interval) + + +class FileExistsWaitStrategy(WaitStrategy): + """ + Wait for a file to exist on the host filesystem. + + This strategy waits for a specific file to exist on the host filesystem, + typically used for waiting for files created by containers via volume mounts. + This is useful for scenarios like Docker-in-Docker where certificate files + need to be generated before they can be used. + + Args: + file_path: Path to the file to wait for (can be str or Path object) + + Example: + # Wait for a certificate file + strategy = FileExistsWaitStrategy("/tmp/certs/ca.pem") + + # Wait for a configuration file + from pathlib import Path + strategy = FileExistsWaitStrategy(Path("/tmp/config/app.conf")) + """ + + def __init__(self, file_path: Union[str, Path]) -> None: + super().__init__() + self._file_path = Path(file_path) + + def with_startup_timeout(self, timeout: Union[int, timedelta]) -> "FileExistsWaitStrategy": + """Set the maximum time to wait for the container to be ready.""" + if isinstance(timeout, timedelta): + self._startup_timeout = int(timeout.total_seconds()) + else: + self._startup_timeout = timeout + return self + + def with_poll_interval(self, interval: Union[float, timedelta]) -> "FileExistsWaitStrategy": + """Set how frequently to check if the container is ready.""" + if isinstance(interval, timedelta): + self._poll_interval = interval.total_seconds() + else: + self._poll_interval = interval + return self + + def wait_until_ready(self, container: WaitStrategyTarget) -> None: + """ + Wait until the specified file exists on the host filesystem. + + Args: + container: The container (used for timeout/polling configuration) + + Raises: + TimeoutError: If the file doesn't exist within the timeout period + """ + start_time = time.time() + + logger.debug( + f"FileExistsWaitStrategy: Waiting for file {self._file_path} with timeout {self._startup_timeout}s" + ) + + while True: + if time.time() - start_time > self._startup_timeout: + # Check what files actually exist in the directory + parent_dir = self._file_path.parent + existing_files = [] + if parent_dir.exists(): + existing_files = [str(f) for f in parent_dir.rglob("*") if f.is_file()] + + logger.error(f"FileExistsWaitStrategy: File {self._file_path} not found after timeout") + logger.debug(f"FileExistsWaitStrategy: Parent directory exists: {parent_dir.exists()}") + logger.debug(f"FileExistsWaitStrategy: Files in parent directory: {existing_files}") + + raise TimeoutError( + f"File {self._file_path} did not exist within {self._startup_timeout:.3f} seconds. " + f"Parent directory exists: {parent_dir.exists()}. " + f"Files in parent directory: {existing_files}. " + f"Hint: Check if the container is configured to create the file at the expected location, " + f"and verify that volume mounts are set up correctly." + ) + + if self._file_path.is_file(): + logger.debug( + f"FileExistsWaitStrategy: File {self._file_path} found after {time.time() - start_time:.2f}s" + ) + return + + logger.debug( + f"FileExistsWaitStrategy: Polling - file {self._file_path} not found yet, elapsed: {time.time() - start_time:.2f}s" + ) + time.sleep(self._poll_interval) + + +class CompositeWaitStrategy(WaitStrategy): + """ + Wait for multiple conditions to be satisfied in sequence. + + This strategy allows combining multiple wait strategies that must all be satisfied. + Each strategy is executed in the order they were added, and all must succeed + for the composite strategy to be considered ready. + + Args: + strategies: Variable number of WaitStrategy objects to execute in sequence + + Example: + # Wait for log message AND file to exist + strategy = CompositeWaitStrategy( + LogMessageWaitStrategy("API listen on"), + FileExistsWaitStrategy("/tmp/certs/ca.pem") + ) + + # Wait for multiple conditions + strategy = CompositeWaitStrategy( + LogMessageWaitStrategy("Database ready"), + PortWaitStrategy(5432), + HttpWaitStrategy(8080, "/health").for_status_code(200) + ) + """ + + def __init__(self, *strategies: WaitStrategy) -> None: + super().__init__() + self._strategies = list(strategies) + + def with_startup_timeout(self, timeout: Union[int, timedelta]) -> "CompositeWaitStrategy": + """ + Set the startup timeout for all contained strategies. + + Args: + timeout: Maximum time to wait in seconds + + Returns: + self for method chaining + """ + if isinstance(timeout, timedelta): + self._startup_timeout = int(timeout.total_seconds()) + else: + self._startup_timeout = timeout + + for strategy in self._strategies: + strategy.with_startup_timeout(timeout) + return self + + def with_poll_interval(self, interval: Union[float, timedelta]) -> "CompositeWaitStrategy": + """ + Set the poll interval for all contained strategies. + + Args: + interval: How frequently to check in seconds + + Returns: + self for method chaining + """ + if isinstance(interval, timedelta): + self._poll_interval = interval.total_seconds() + else: + self._poll_interval = interval + + for strategy in self._strategies: + strategy.with_poll_interval(interval) + return self + + def wait_until_ready(self, container: WaitStrategyTarget) -> None: + """ + Wait until all contained strategies are ready. + + Args: + container: The container to monitor + + Raises: + TimeoutError: If any strategy doesn't become ready within the timeout period + """ + logger.debug(f"CompositeWaitStrategy: Starting execution of {len(self._strategies)} strategies") + + for i, strategy in enumerate(self._strategies): + try: + logger.debug( + f"CompositeWaitStrategy: Executing strategy {i + 1}/{len(self._strategies)}: {type(strategy).__name__}" + ) + strategy.wait_until_ready(container) + logger.debug(f"CompositeWaitStrategy: Strategy {i + 1}/{len(self._strategies)} completed successfully") + except TimeoutError as e: + logger.error(f"CompositeWaitStrategy: Strategy {i + 1}/{len(self._strategies)} failed: {e}") + raise TimeoutError( + f"Composite wait strategy failed at step {i + 1}/{len(self._strategies)}: {e}" + ) from e + + logger.debug("CompositeWaitStrategy: All strategies completed successfully") + + +__all__ = [ + "CompositeWaitStrategy", + "FileExistsWaitStrategy", + "HealthcheckWaitStrategy", + "HttpWaitStrategy", + "LogMessageWaitStrategy", + "PortWaitStrategy", + "WaitStrategy", + "WaitStrategyTarget", +] diff --git a/core/tests/test_wait_strategies.py b/core/tests/test_wait_strategies.py index 9ef4d258..19bfbc41 100644 --- a/core/tests/test_wait_strategies.py +++ b/core/tests/test_wait_strategies.py @@ -1,17 +1,21 @@ -import itertools import re import time -import typing from datetime import timedelta -from unittest.mock import Mock, patch - +from unittest.mock import Mock, patch, MagicMock import pytest +import itertools -from testcontainers.core.wait_strategies import LogMessageWaitStrategy -from testcontainers.core.waiting_utils import WaitStrategy - -if typing.TYPE_CHECKING: - from testcontainers.core.waiting_utils import WaitStrategyTarget +from testcontainers.core.container import DockerContainer +from testcontainers.core.wait_strategies import ( + CompositeWaitStrategy, + WaitStrategyTarget, + FileExistsWaitStrategy, + HealthcheckWaitStrategy, + HttpWaitStrategy, + LogMessageWaitStrategy, + PortWaitStrategy, + WaitStrategy, +) class ConcreteWaitStrategy(WaitStrategy): @@ -148,3 +152,536 @@ def test_wait_until_ready(self, mock_sleep, mock_time, container_logs, expected_ else: with pytest.raises(TimeoutError): strategy.wait_until_ready(mock_container) + + +class TestHttpWaitStrategy: + """Test the HttpWaitStrategy class.""" + + @pytest.mark.parametrize( + "port,path,expected_port,expected_path", + [ + (8080, "/health", 8080, "/health"), + (80, None, 80, "/"), + (443, "/api/status", 443, "/api/status"), + (3000, "", 3000, "/"), + ], + ids=[ + "port_8080_health_path", + "port_80_default_path", + "port_443_api_status_path", + "port_3000_empty_path", + ], + ) + def test_http_wait_strategy_initialization(self, port, path, expected_port, expected_path): + strategy = HttpWaitStrategy(port, path) + assert strategy._port == expected_port + assert strategy._path == expected_path + assert strategy._status_codes == {200} + assert strategy._method == "GET" + + @pytest.mark.parametrize( + "status_codes", + [ + [404], + [200, 201], + [500, 502, 503], + [200, 404, 500], + ], + ids=[ + "single_status_code_404", + "multiple_status_codes_200_201", + "error_status_codes_500_502_503", + "mixed_status_codes_200_404_500", + ], + ) + def test_for_status_code(self, status_codes): + strategy = HttpWaitStrategy(8080) + + for code in status_codes: + result = strategy.for_status_code(code) + assert result is strategy + assert code in strategy._status_codes + + # Default 200 should still be there + assert 200 in strategy._status_codes + + @pytest.mark.parametrize( + "predicate_description,status_code_predicate,response_predicate", + [ + ("status_code_range", lambda code: 200 <= code < 300, None), + ("status_code_equals_200", lambda code: code == 200, None), + ("response_contains_ready", None, lambda body: "ready" in body), + ("response_json_valid", None, lambda body: "status" in body), + ("both_predicates", lambda code: code >= 200, lambda body: len(body) > 0), + ], + ids=[ + "status_code_range_200_to_300", + "status_code_equals_200", + "response_contains_ready", + "response_json_valid", + "both_status_and_response_predicates", + ], + ) + def test_predicates(self, predicate_description, status_code_predicate, response_predicate): + strategy = HttpWaitStrategy(8080) + + if status_code_predicate: + result = strategy.for_status_code_matching(status_code_predicate) + assert result is strategy + assert strategy._status_code_predicate is status_code_predicate + + if response_predicate: + result = strategy.for_response_predicate(response_predicate) + assert result is strategy + assert strategy._response_predicate is response_predicate + + @pytest.mark.parametrize( + "tls_config,expected_tls,expected_insecure", + [ + ({"insecure": True}, True, True), + ({"insecure": False}, True, False), + ({}, True, False), + ], + ids=[ + "tls_insecure_true", + "tls_insecure_false", + "tls_default_insecure_false", + ], + ) + def test_using_tls(self, tls_config, expected_tls, expected_insecure): + strategy = HttpWaitStrategy(8080) + result = strategy.using_tls(**tls_config) + assert result is strategy + assert strategy._tls is expected_tls + assert strategy._insecure_tls is expected_insecure + + @pytest.mark.parametrize( + "headers", + [ + {"Authorization": "Bearer token"}, + {"Content-Type": "application/json"}, + {"User-Agent": "test", "Accept": "text/html"}, + ], + ids=[ + "single_header_authorization", + "single_header_content_type", + "multiple_headers_user_agent_accept", + ], + ) + def test_with_header(self, headers): + strategy = HttpWaitStrategy(8080) + + for key, value in headers.items(): + result = strategy.with_header(key, value) + assert result is strategy + assert strategy._headers[key] == value + + @pytest.mark.parametrize( + "credentials", + [ + ("user", "pass"), + ("admin", "secret123"), + ("test", ""), + ], + ids=[ + "basic_credentials_user_pass", + "basic_credentials_admin_secret", + "basic_credentials_test_empty", + ], + ) + def test_with_basic_credentials(self, credentials): + strategy = HttpWaitStrategy(8080) + result = strategy.with_basic_credentials(*credentials) + assert result is strategy + assert strategy._basic_auth == credentials + + @pytest.mark.parametrize( + "method", + [ + "GET", + "POST", + "PUT", + "DELETE", + "HEAD", + ], + ids=[ + "method_get", + "method_post", + "method_put", + "method_delete", + "method_head", + ], + ) + def test_with_method(self, method): + strategy = HttpWaitStrategy(8080) + result = strategy.with_method(method) + assert result is strategy + assert strategy._method == method + + @pytest.mark.parametrize( + "body", + [ + '{"key": "value"}', + '{"status": "ready"}', + "data=test&format=json", + "", + ], + ids=[ + "json_body_key_value", + "json_body_status_ready", + "form_data_body", + "empty_body", + ], + ) + def test_with_body(self, body): + strategy = HttpWaitStrategy(8080) + result = strategy.with_body(body) + assert result is strategy + assert strategy._body == body + + @pytest.mark.parametrize( + "url,expected_port,expected_path,expected_tls", + [ + ("https://localhost:8080/api/health", 8080, "/api/health", True), + ("http://localhost:3000", 3000, "/", False), + ("https://example.com", 443, "/", True), + ("http://localhost:80/", 80, "/", False), + ], + ids=[ + "https_localhost_8080_api_health", + "http_localhost_3000_default_path", + "https_example_com_default_port", + "http_localhost_80_root_path", + ], + ) + def test_from_url(self, url, expected_port, expected_path, expected_tls): + strategy = HttpWaitStrategy.from_url(url) + assert strategy._port == expected_port + assert strategy._path == expected_path + assert strategy._tls is expected_tls + + +class TestHealthcheckWaitStrategy: + """Test the HealthcheckWaitStrategy class.""" + + def test_healthcheck_wait_strategy_initialization(self): + strategy = HealthcheckWaitStrategy() + # Should inherit from WaitStrategy + assert hasattr(strategy, "_startup_timeout") + assert hasattr(strategy, "_poll_interval") + + @pytest.mark.parametrize( + "health_status,health_log,should_succeed,expected_error", + [ + ("healthy", None, True, None), + ( + "unhealthy", + [{"ExitCode": 1, "Output": "Health check failed"}], + False, + "Container health check reported unhealthy", + ), + ("starting", None, False, "Container health check did not report healthy within 120 seconds"), + (None, None, False, "No health check configured"), + ], + ids=[ + "healthy_status_success", + "unhealthy_status_failure", + "starting_status_failure", + "no_healthcheck_failure", + ], + ) + @patch("time.time") + @patch("time.sleep") + def test_wait_until_ready(self, mock_sleep, mock_time, health_status, health_log, should_succeed, expected_error): + strategy = HealthcheckWaitStrategy() + mock_container = Mock() + + # Mock the wrapped container + mock_wrapped = Mock() + mock_wrapped.status = "running" + mock_wrapped.reload.return_value = None + + # Mock health check data + health_data = {} + if health_status: + health_data = {"Status": health_status} + if health_log: + health_data["Log"] = health_log + + mock_wrapped.attrs = {"State": {"Health": health_data}} + + mock_container.get_wrapped_container.return_value = mock_wrapped + + # Configure time mock based on expected behavior + if should_succeed: + mock_time.side_effect = [0, 1] + else: + # For failure cases, we need more time values to handle the loop + mock_time.side_effect = itertools.count(start=0, step=1) # Provide enough values for the loop + + if should_succeed: + strategy.wait_until_ready(mock_container) + else: + with pytest.raises((RuntimeError, TimeoutError), match=expected_error): + strategy.wait_until_ready(mock_container) + + +class TestPortWaitStrategy: + """Test the PortWaitStrategy class.""" + + @pytest.mark.parametrize( + "port", + [ + 8080, + 80, + 443, + 22, + 3306, + ], + ids=[ + "port_8080", + "port_80", + "port_443", + "port_22", + "port_3306", + ], + ) + def test_port_wait_strategy_initialization(self, port): + strategy = PortWaitStrategy(port) + assert strategy._port == port + + @pytest.mark.parametrize( + "connection_success,expected_behavior", + [ + (True, "success"), + (False, "timeout"), + ], + ids=[ + "socket_connection_success", + "socket_connection_timeout", + ], + ) + @patch("socket.socket") + @patch("time.time") + @patch("time.sleep") + def test_wait_until_ready(self, mock_sleep, mock_time, mock_socket, connection_success, expected_behavior): + strategy = PortWaitStrategy(8080).with_startup_timeout(1) + mock_container = Mock() + mock_container.get_container_host_ip.return_value = "localhost" + mock_container.get_exposed_port.return_value = 8080 + + # Mock socket connection + mock_socket_instance = Mock() + if connection_success: + mock_socket.return_value.__enter__.return_value = mock_socket_instance + mock_time.side_effect = [0, 1] + else: + mock_socket_instance.connect.side_effect = ConnectionRefusedError() + mock_socket.return_value.__enter__.return_value = mock_socket_instance + mock_time.side_effect = [0, 2] # Exceed timeout + + if expected_behavior == "success": + strategy.wait_until_ready(mock_container) + mock_socket_instance.connect.assert_called_once_with(("localhost", 8080)) + else: + with pytest.raises(TimeoutError, match="Port 8080 not available within 1 seconds"): + strategy.wait_until_ready(mock_container) + + +class TestFileExistsWaitStrategy: + """Test the FileExistsWaitStrategy class.""" + + @pytest.mark.parametrize( + "file_path", + [ + "/tmp/test.txt", + "/var/log/app.log", + "/opt/app/config.yaml", + "relative/path/file.conf", + ], + ids=[ + "tmp_file_path", + "var_log_file_path", + "opt_config_file_path", + "relative_file_path", + ], + ) + def test_file_exists_wait_strategy_initialization(self, file_path): + strategy = FileExistsWaitStrategy(file_path) + # _file_path is stored as a Path object + assert str(strategy._file_path) == file_path + # Should inherit from WaitStrategy + assert hasattr(strategy, "_startup_timeout") + assert hasattr(strategy, "_poll_interval") + + @pytest.mark.parametrize( + "file_exists,expected_behavior", + [ + (True, "success"), + (False, "timeout"), + ], + ids=[ + "file_exists_success", + "file_not_exists_timeout", + ], + ) + @patch("pathlib.Path.is_file") + @patch("time.time") + @patch("time.sleep") + def test_wait_until_ready(self, mock_sleep, mock_time, mock_is_file, file_exists, expected_behavior): + strategy = FileExistsWaitStrategy("/tmp/test.txt").with_startup_timeout(1) + mock_container = Mock() + + # Configure mocks based on expected behavior + if file_exists: + mock_is_file.return_value = True + # Need multiple time values for debug logging + mock_time.side_effect = [0, 0.1, 0.2] # Start time, check time, logging time + else: + mock_is_file.return_value = False + # Need more time values for the loop and logging calls + mock_time.side_effect = itertools.count(start=0, step=0.6) # Exceed timeout after a few iterations + + if expected_behavior == "success": + strategy.wait_until_ready(mock_container) + mock_is_file.assert_called() + else: + with pytest.raises(TimeoutError, match="File.*did not exist within.*seconds"): + strategy.wait_until_ready(mock_container) + + +class TestCompositeWaitStrategy: + """Test the CompositeWaitStrategy class.""" + + def test_composite_wait_strategy_initialization_single_strategy(self): + """Test initialization with a single strategy.""" + log_strategy = LogMessageWaitStrategy("ready") + composite = CompositeWaitStrategy(log_strategy) + assert composite._strategies == [log_strategy] + + def test_composite_wait_strategy_initialization_multiple_strategies(self): + """Test initialization with multiple strategies.""" + log_strategy = LogMessageWaitStrategy("ready") + port_strategy = PortWaitStrategy(8080) + file_strategy = FileExistsWaitStrategy("/tmp/test.txt") + + composite = CompositeWaitStrategy(log_strategy, port_strategy, file_strategy) + assert composite._strategies == [log_strategy, port_strategy, file_strategy] + + def test_composite_wait_strategy_initialization_empty(self): + """Test that empty initialization works (creates empty list).""" + composite = CompositeWaitStrategy() + assert composite._strategies == [] + + def test_with_startup_timeout_propagates_to_child_strategies(self): + """Test that timeout setting propagates to child strategies.""" + log_strategy = LogMessageWaitStrategy("ready") + composite = CompositeWaitStrategy(log_strategy) + result = composite.with_startup_timeout(30) + assert result is composite + assert composite._startup_timeout == 30 + # Should also propagate to child strategies + assert log_strategy._startup_timeout == 30 + + def test_with_poll_interval_propagates_to_child_strategies(self): + """Test that poll interval setting propagates to child strategies.""" + port_strategy = PortWaitStrategy(8080) + composite = CompositeWaitStrategy(port_strategy) + result = composite.with_poll_interval(2.0) + assert result is composite + assert composite._poll_interval == 2.0 + # Should also propagate to child strategies + assert port_strategy._poll_interval == 2.0 + + def test_wait_until_ready_all_strategies_succeed(self): + """Test that all strategies are executed when they all succeed.""" + # Create mock strategies + strategy1 = Mock() + strategy2 = Mock() + strategy3 = Mock() + + composite = CompositeWaitStrategy(strategy1, strategy2, strategy3) + mock_container = Mock() + + # All strategies should succeed + strategy1.wait_until_ready.return_value = None + strategy2.wait_until_ready.return_value = None + strategy3.wait_until_ready.return_value = None + + composite.wait_until_ready(mock_container) + + # Verify all strategies were called in order + strategy1.wait_until_ready.assert_called_once_with(mock_container) + strategy2.wait_until_ready.assert_called_once_with(mock_container) + strategy3.wait_until_ready.assert_called_once_with(mock_container) + + def test_wait_until_ready_first_strategy_fails(self): + """Test that execution stops when first strategy fails.""" + strategy1 = Mock() + strategy2 = Mock() + strategy3 = Mock() + + composite = CompositeWaitStrategy(strategy1, strategy2, strategy3) + mock_container = Mock() + + # First strategy fails + strategy1.wait_until_ready.side_effect = TimeoutError("First strategy failed") + + with pytest.raises(TimeoutError, match="First strategy failed"): + composite.wait_until_ready(mock_container) + + # Only first strategy should be called + strategy1.wait_until_ready.assert_called_once_with(mock_container) + strategy2.wait_until_ready.assert_not_called() + strategy3.wait_until_ready.assert_not_called() + + def test_wait_until_ready_middle_strategy_fails(self): + """Test that execution stops when middle strategy fails.""" + strategy1 = Mock() + strategy2 = Mock() + strategy3 = Mock() + + composite = CompositeWaitStrategy(strategy1, strategy2, strategy3) + mock_container = Mock() + + # First succeeds, second fails + strategy1.wait_until_ready.return_value = None + strategy2.wait_until_ready.side_effect = RuntimeError("Second strategy failed") + + with pytest.raises(RuntimeError, match="Second strategy failed"): + composite.wait_until_ready(mock_container) + + # First two strategies should be called + strategy1.wait_until_ready.assert_called_once_with(mock_container) + strategy2.wait_until_ready.assert_called_once_with(mock_container) + strategy3.wait_until_ready.assert_not_called() + + @pytest.mark.parametrize( + "strategy_types,expected_count", + [ + (["log"], 1), + (["log", "port"], 2), + (["log", "port", "file"], 3), + (["file", "log", "port", "file"], 4), + ], + ids=[ + "single_log_strategy", + "log_and_port_strategies", + "three_different_strategies", + "four_strategies_with_duplicate_type", + ], + ) + def test_composite_strategy_count(self, strategy_types, expected_count): + """Test that composite strategy handles different numbers of strategies.""" + strategies: list[WaitStrategy] = [] + for strategy_type in strategy_types: + if strategy_type == "log": + strategies.append(LogMessageWaitStrategy("ready")) + elif strategy_type == "port": + strategies.append(PortWaitStrategy(8080)) + elif strategy_type == "file": + strategies.append(FileExistsWaitStrategy("/tmp/test.txt")) + + composite = CompositeWaitStrategy(*strategies) + assert len(composite._strategies) == expected_count + assert composite._strategies == strategies