From c4ee4a3c02d1deb8faac2fe8c5303eac3964b4eb Mon Sep 17 00:00:00 2001 From: monoxgas Date: Thu, 2 Oct 2025 01:40:44 -0600 Subject: [PATCH 1/3] Add pull command. Refactor docker cli mechs. --- docs/sdk/api.mdx | 90 ++- docs/sdk/main.mdx | 6 +- dreadnode/api/client.py | 28 +- dreadnode/cli/docker.py | 387 +++++++++++++ dreadnode/cli/github.py | 54 +- dreadnode/cli/main.py | 124 ++-- dreadnode/cli/platform/cli.py | 291 +++++++--- dreadnode/cli/platform/compose.py | 212 +++++++ dreadnode/cli/platform/configure.py | 47 -- dreadnode/cli/platform/constants.py | 10 +- dreadnode/cli/platform/docker_.py | 546 ------------------ dreadnode/cli/platform/download.py | 163 ++---- .../cli/platform/{utils => }/env_mgmt.py | 23 +- dreadnode/cli/platform/login.py | 20 - dreadnode/cli/platform/schemas.py | 117 ---- dreadnode/cli/platform/start.py | 79 --- dreadnode/cli/platform/status.py | 43 -- dreadnode/cli/platform/stop.py | 21 - dreadnode/cli/platform/tag.py | 40 ++ dreadnode/cli/platform/upgrade.py | 64 -- dreadnode/cli/platform/utils/__init__.py | 0 dreadnode/cli/platform/utils/printing.py | 43 -- dreadnode/cli/platform/utils/versions.py | 184 ------ dreadnode/cli/platform/version.py | 216 +++++++ dreadnode/cli/profile/cli.py | 30 +- dreadnode/cli/shared.py | 4 +- dreadnode/logging_.py | 81 +-- dreadnode/main.py | 8 +- dreadnode/user_config.py | 4 +- pyproject.toml | 2 +- tests/cli/test_config.py | 79 +++ tests/cli/test_docker.py | 161 ++++++ tests/cli/test_github.py | 203 +++++++ uv.lock | 2 +- 34 files changed, 1803 insertions(+), 1579 deletions(-) create mode 100644 dreadnode/cli/docker.py create mode 100644 dreadnode/cli/platform/compose.py delete mode 100644 dreadnode/cli/platform/configure.py delete mode 100644 dreadnode/cli/platform/docker_.py rename dreadnode/cli/platform/{utils => }/env_mgmt.py (96%) delete mode 100644 dreadnode/cli/platform/login.py delete mode 100644 dreadnode/cli/platform/schemas.py delete mode 100644 dreadnode/cli/platform/start.py delete mode 100644 dreadnode/cli/platform/status.py delete mode 100644 dreadnode/cli/platform/stop.py create mode 100644 dreadnode/cli/platform/tag.py delete mode 100644 dreadnode/cli/platform/upgrade.py delete mode 100644 dreadnode/cli/platform/utils/__init__.py delete mode 100644 dreadnode/cli/platform/utils/printing.py delete mode 100644 dreadnode/cli/platform/utils/versions.py create mode 100644 dreadnode/cli/platform/version.py create mode 100644 tests/cli/test_config.py create mode 100644 tests/cli/test_docker.py create mode 100644 tests/cli/test_github.py diff --git a/docs/sdk/api.mdx b/docs/sdk/api.mdx index 2512b872..c8ab6597 100644 --- a/docs/sdk/api.mdx +++ b/docs/sdk/api.mdx @@ -505,37 +505,6 @@ def export_timeseries( ``` - - -### get\_container\_registry\_credentials - -```python -get_container_registry_credentials() -> ( - ContainerRegistryCredentials -) -``` - -Retrieves container registry credentials for Docker image access. - -**Returns:** - -* `ContainerRegistryCredentials` - –The container registry credentials object. - - -```python -def get_container_registry_credentials(self) -> ContainerRegistryCredentials: - """ - Retrieves container registry credentials for Docker image access. - - Returns: - The container registry credentials object. - """ - response = self.request("POST", "/platform/registry-token") - return ContainerRegistryCredentials(**response.json()) -``` - - ### get\_device\_codes @@ -577,13 +546,44 @@ def get_github_access_token(self, repos: list[str]) -> GithubTokenResponse: ``` + + +### get\_platform\_registry\_credentials + +```python +get_platform_registry_credentials() -> ( + ContainerRegistryCredentials +) +``` + +Retrieves container registry credentials for Docker image access. + +**Returns:** + +* `ContainerRegistryCredentials` + –The container registry credentials object. + + +```python +def get_platform_registry_credentials(self) -> ContainerRegistryCredentials: + """ + Retrieves container registry credentials for Docker image access. + + Returns: + The container registry credentials object. + """ + response = self.request("POST", "/platform/registry-token") + return ContainerRegistryCredentials(**response.json()) +``` + + ### get\_platform\_releases ```python get_platform_releases( - tag: str, services: list[str], cli_version: str | None + tag: str, services: list[str] ) -> RegistryImageDetails ``` @@ -596,35 +596,17 @@ Resolves the platform releases for the current project. ```python -def get_platform_releases( - self, tag: str, services: list[str], cli_version: str | None -) -> RegistryImageDetails: +def get_platform_releases(self, tag: str, services: list[str]) -> RegistryImageDetails: """ Resolves the platform releases for the current project. Returns: The resolved platform releases as a ResolveReleasesResponse object. """ - payload = { - "tag": tag, - "services": services, - "cli_version": cli_version, - } - try: - response = self.request("POST", "/platform/get-releases", json_data=payload) - - except RuntimeError as e: - if "403" in str(e): - raise RuntimeError("You do not have access to platform releases.") from e - - if "404" in str(e): - if "Image not found" in str(e): - raise RuntimeError("Image not found") from e + from dreadnode.version import VERSION - raise RuntimeError( - f"Failed to get platform releases: {e}. The feature is likely disabled on this server" - ) from e - raise + payload = {"tag": tag, "services": services, "cli_version": VERSION} + response = self.request("POST", "/platform/get-releases", json_data=payload) return RegistryImageDetails(**response.json()) ``` diff --git a/docs/sdk/main.mdx b/docs/sdk/main.mdx index 40f6c8c6..be900b5e 100644 --- a/docs/sdk/main.mdx +++ b/docs/sdk/main.mdx @@ -329,11 +329,13 @@ def configure( # Log config information for clarity if self.server or self.token or self.local_dir: destination = self.server or DEFAULT_SERVER_URL or "local storage" - rich.print(f"Dreadnode logging to [orange_red1]{destination}[/] ({config_source})") + logging_console.print( + f"Dreadnode logging to [orange_red1]{destination}[/] ({config_source})" + ) # Warn the user if the profile didn't resolve elif active_profile and not (self.server or self.token): - rich.print( + logging_console.print( f":exclamation: Dreadnode profile [orange_red1]{active_profile}[/] appears invalid." ) diff --git a/dreadnode/api/client.py b/dreadnode/api/client.py index ad10de48..181423ac 100644 --- a/dreadnode/api/client.py +++ b/dreadnode/api/client.py @@ -726,7 +726,7 @@ def get_user_data_credentials(self) -> UserDataCredentials: # Container registry access - def get_container_registry_credentials(self) -> ContainerRegistryCredentials: + def get_platform_registry_credentials(self) -> ContainerRegistryCredentials: """ Retrieves container registry credentials for Docker image access. @@ -736,35 +736,17 @@ def get_container_registry_credentials(self) -> ContainerRegistryCredentials: response = self.request("POST", "/platform/registry-token") return ContainerRegistryCredentials(**response.json()) - def get_platform_releases( - self, tag: str, services: list[str], cli_version: str | None - ) -> RegistryImageDetails: + def get_platform_releases(self, tag: str, services: list[str]) -> RegistryImageDetails: """ Resolves the platform releases for the current project. Returns: The resolved platform releases as a ResolveReleasesResponse object. """ - payload = { - "tag": tag, - "services": services, - "cli_version": cli_version, - } - try: - response = self.request("POST", "/platform/get-releases", json_data=payload) - - except RuntimeError as e: - if "403" in str(e): - raise RuntimeError("You do not have access to platform releases.") from e - - if "404" in str(e): - if "Image not found" in str(e): - raise RuntimeError("Image not found") from e + from dreadnode.version import VERSION - raise RuntimeError( - f"Failed to get platform releases: {e}. The feature is likely disabled on this server" - ) from e - raise + payload = {"tag": tag, "services": services, "cli_version": VERSION} + response = self.request("POST", "/platform/get-releases", json_data=payload) return RegistryImageDetails(**response.json()) def get_platform_templates(self, tag: str) -> bytes: diff --git a/dreadnode/cli/docker.py b/dreadnode/cli/docker.py new file mode 100644 index 00000000..83d06c9f --- /dev/null +++ b/dreadnode/cli/docker.py @@ -0,0 +1,387 @@ +import re +import subprocess # nosec +import typing as t + +from pydantic import AliasChoices, BaseModel, Field + +from dreadnode.common_types import UNSET, Unset +from dreadnode.constants import ( + DEFAULT_DOCKER_REGISTRY_IMAGE_TAG, + DEFAULT_DOCKER_REGISTRY_LOCAL_PORT, + DEFAULT_DOCKER_REGISTRY_SUBDOMAIN, + DEFAULT_PLATFORM_BASE_DOMAIN, +) +from dreadnode.logging_ import print_info +from dreadnode.user_config import ServerConfig + +DockerContainerState = t.Literal[ + "running", "exited", "paused", "restarting", "removing", "created", "dead" +] + + +class DockerError(Exception): + pass + + +class DockerImage(str): # noqa: SLOT000 + """ + A string subclass that normalizes and parses various Docker image string formats. + + Supported formats: + - ubuntu + - ubuntu:22.04 + - library/ubuntu:22.04 + - docker.io/library/ubuntu:22.04 + - myregistry:5000/my/image:latest + - myregistry:5000/my/image@sha256:f6e42a... + - dreadnode/image (correctly parsed as a Docker Hub image) + """ + + registry: str | None + repository: str + tag: str | None + digest: str | None + + def __new__(cls, value: str, *_: t.Any, **__: t.Any) -> "DockerImage": + value = value.strip() + if not value: + raise ValueError("Invalid Docker image format: input cannot be empty") + + # 1. Separate digest from the rest + digest: str | None = None + if "@" in value: + value, digest = value.split("@", 1) + + # 2. Separate tag from the repository path + tag: str | None = None + repo_path = value + # A tag is present if there's a colon that is NOT part of a port number in a registry hostname + if ":" in value: + possible_repo, possible_tag = value.rsplit(":", 1) + # If the part before the colon contains a slash, or no slash at all, it's a tag. + # This correctly handles "ubuntu:22.04" and "gcr.io/my/image:tag" but not "localhost:5000/image". + if "/" in possible_repo or "/" not in value: + repo_path, tag = possible_repo, possible_tag + + if not repo_path: + raise ValueError("Invalid Docker image format: missing repository name") + + # 3. Determine the registry and the final repository name + registry: str | None = None + repository = repo_path + + if "/" not in repo_path: + # Case 1: An official image like "ubuntu". + repository = f"library/{repo_path}" + else: + # Case 2: A namespaced path. It could be "dreadnode/image" + # or "gcr.io/google-containers/busybox". + first_part = repo_path.split("/", 1)[0] + if "." in first_part or ":" in first_part: + # If the first part has a "." or ":", it's a registry hostname. + registry = first_part + repository = repo_path.split("/", 1)[1] + + # 4. Default to 'latest' tag if no tag or digest is provided + if not tag and not digest: + tag = "latest" + + # 5. Construct the full, normalized string for the object's value + full_image_str = repository + if registry: + full_image_str = f"{registry}/{repository}" + + if tag: + full_image_str += f":{tag}" + if digest: + full_image_str += f"@{digest}" + + obj = super().__new__(cls, full_image_str) + obj.registry = registry + obj.repository = repository + obj.tag = tag + obj.digest = digest + + return obj + + def __repr__(self) -> str: + parts = [f"repository='{self.repository}'"] + if self.registry: + parts.append(f"registry='{self.registry}'") + if self.tag: + parts.append(f"tag='{self.tag}'") + if self.digest: + parts.append(f"digest='{self.digest}'") + return f"{self.__class__.__name__}({', '.join(parts)})" + + def with_( + self, + *, + repository: str | Unset = UNSET, + registry: str | None | Unset = UNSET, + tag: str | None | Unset = UNSET, + digest: str | None | Unset = UNSET, + ) -> "DockerImage": + """ + Create a new DockerImage instance with updated elements. + """ + new_registry = registry if not isinstance(registry, Unset) else self.registry + new_repository = repository if not isinstance(repository, Unset) else self.repository + new_tag = tag if not isinstance(tag, Unset) else self.tag + new_digest = digest if not isinstance(digest, Unset) else self.digest + + new_image = new_repository + if new_registry: + new_image = f"{new_registry}/{new_repository}" + + if new_tag: + new_image += f":{new_tag}" + if new_digest: + new_image += f"@{new_digest}" + + return DockerImage(new_image) + + +class DockerContainer(BaseModel): + id: str = Field(..., alias="ID") + name: str = Field(..., validation_alias=AliasChoices("Name", "Names")) + exit_code: int = Field(-1, alias="ExitCode") + state: DockerContainerState = Field(..., alias="State") + status: str = Field(..., alias="Status") + raw_ports: str = Field(..., alias="Ports") + image: str = Field(..., alias="Image") + command: str = Field(..., alias="Command") + + @property + def is_running(self) -> bool: + return self.state == "running" + + @property + def ports(self) -> list[tuple[int, int]]: + """ + Parse the raw_ports string into a list of tuples mapping host ports to container ports. + """ + ports = [] + for mapping in self.raw_ports.split(","): + host_part, container_part = mapping.split("->") + host_port = int(host_part.split(":")[-1]) + container_port = int(container_part.split("/")[0]) + ports.append((host_port, container_port)) + return ports + + +def docker_run( + args: list[str], + *, + timeout: int = 300, + stdin_input: str | None = None, + capture_output: bool = False, +) -> subprocess.CompletedProcess[str]: + """ + Execute a docker command with common error handling and configuration. + + Args: + args: Additional arguments for the docker command. + timeout: Command timeout in seconds. + stdin_input: Optional input string to pass to the command's stdin. + capture_output: Whether to capture the command's output. + + Returns: + CompletedProcess object with command results. + + Raises: + subprocess.CalledProcessError: If command fails. + subprocess.TimeoutExpired: If command times out. + FileNotFoundError: If docker/docker-compose not found. + """ + try: + result = subprocess.run( # noqa: S603 # nosec + ["docker", *args], # noqa: S607 + check=True, + text=True, + timeout=timeout, + encoding="utf-8", + errors="replace", + input=stdin_input, + capture_output=capture_output, + ) + if not capture_output: + print_info("") # Some padding after command output + + except subprocess.CalledProcessError as e: + command_str = " ".join(e.cmd) + error = f"Docker command failed: {command_str}" + if e.stderr: + error += f"\n\n{e.stderr}" + raise DockerError(error) from e + + except subprocess.TimeoutExpired as e: + command_str = " ".join(e.cmd) + raise DockerError(f"Docker command timed out after {timeout} seconds: {command_str}") from e + + except FileNotFoundError as e: + raise DockerError( + "`docker` not found, please ensure it is installed and in your PATH." + ) from e + + return result + + +def get_available_local_images() -> list[DockerImage]: + """ + Get the list of available Docker images on the local system. + + Returns: + List of available Docker image names. + """ + result = docker_run( + ["images", "--format", "{{.Repository}}:{{.Tag}}@{{.Digest}}"], + capture_output=True, + timeout=30, + ) + return [DockerImage(line) for line in result.stdout.splitlines() if line.strip()] + + +def get_env_var_from_container(container_name: str, var_name: str) -> str | None: + """ + Get the specified environment variable from the container and return + its value. + + Args: + container_name: Name of the container to inspect. + var_name: Name of the environment variable to retrieve. + + Returns: + str | None: Value of the environment variable, or None if not found. + """ + result = docker_run( + ["inspect", "-f", "{{range .Config.Env}}{{println .}}{{end}}", container_name], + capture_output=True, + timeout=30, + ) + for line in result.stdout.splitlines(): + if line.startswith(f"{var_name.upper()}="): + return line.split("=", 1)[1] + return None + + +def docker_login(registry: str, username: str, password: str) -> None: + """ + Log into a Docker registry. + + Args: + registry: Registry hostname to log into. + username: Username for the registry. + password: Password for the registry. + """ + docker_run( + ["login", registry, "--username", username, "--password-stdin"], + stdin_input=password, + capture_output=True, + timeout=60, + ) + + +def docker_ps() -> list[DockerContainer]: + """ + List and parse running containers using `docker ps`. + + Returns: + A list of DockerPSResult objects. + """ + result = docker_run( + ["ps", "--format", "json"], + capture_output=True, + ) + return [ + DockerContainer.model_validate_json(line) + for line in result.stdout.splitlines() + if line.strip() + ] + + +def docker_compose_ps(args: list[str] | None = None) -> list[DockerContainer]: + """ + List and parse running containers using `docker compose ps`. + + This mirrors: + docker compose [*args] ps --format json + + Args: + args: Additional docker compose arguments. + + Returns: + A list of DockerPSResult objects. + """ + result = docker_run( + ["compose", *(args or []), "ps", "--format", "json"], + capture_output=True, + ) + return [ + DockerContainer.model_validate_json(line) + for line in result.stdout.splitlines() + if line.strip() + ] + + +def docker_tag(image: str | DockerImage, new_tag: str) -> None: + """ + Tag a Docker image with a new tag. + + Args: + image: The name of the image to tag. + new_tag: The new tag to apply to the image. + """ + docker_run( + ["tag", str(image), new_tag], + capture_output=True, + timeout=60, + ) + + +def get_local_registry_port() -> int: + for container in docker_ps(): + if DEFAULT_DOCKER_REGISTRY_IMAGE_TAG in container.image and container.ports: + # return the first mapped port + return container.ports[0][0] + + # fallback to the default port + return DEFAULT_DOCKER_REGISTRY_LOCAL_PORT + + +def get_registry(config: ServerConfig) -> str: + # localhost is a special case + if "localhost" in config.url or "127.0.0.1" in config.url: + return f"localhost:{get_local_registry_port()}" + + prefix = "" + if "staging-" in config.url: + prefix = "staging-" + elif "dev-" in config.url: + prefix = "dev-" + + return f"{prefix}{DEFAULT_DOCKER_REGISTRY_SUBDOMAIN}.{DEFAULT_PLATFORM_BASE_DOMAIN}" + + +def docker_pull(image: str | DockerImage) -> None: + """ + Pull a Docker image. + + Args: + image: The name of the image to pull. + """ + docker_run(["pull", image]) + + +def clean_username(name: str) -> str: + """ + Sanitizes an agent or user name to be used in a Docker repository URI. + """ + # convert to lowercase + name = name.lower() + # replace non-alphanumeric characters with hyphens + name = re.sub(r"[^\w\s-]", "", name) + # replace one or more whitespace characters with a single hyphen + name = re.sub(r"[-\s]+", "-", name) + # remove leading or trailing hyphens + return name.strip("-") diff --git a/dreadnode/cli/github.py b/dreadnode/cli/github.py index 3ca6ea55..0731d03c 100644 --- a/dreadnode/cli/github.py +++ b/dreadnode/cli/github.py @@ -6,9 +6,9 @@ import zipfile import httpx -import rich from rich.prompt import Prompt +from dreadnode.logging_ import confirm, console, print_info, print_warning from dreadnode.user_config import UserConfig, find_dreadnode_saas_profiles, is_dreadnode_saas_server @@ -174,7 +174,7 @@ def download_and_unzip_archive(url: str, *, headers: dict[str, str] | None = Non temp_dir = pathlib.Path(tempfile.mkdtemp()) local_zip_path = temp_dir / "archive.zip" - rich.print(f":arrow_double_down: Downloading {url} ...") + print_info(f"Downloading {url} ...") # download to temporary file with httpx.stream("GET", url, follow_redirects=True, verify=True, headers=headers) as response: @@ -216,58 +216,50 @@ def validate_server_for_clone(user_config: UserConfig, current_profile: str | No return current_profile or user_config.active_profile_name # Current server is not a Dreadnode SaaS server - warn user - rich.print() - rich.print(":warning: [yellow]Warning: Current server is not a Dreadnode SaaS server[/]") - rich.print(f" Current server: [cyan]{current_server}[/]") - rich.print(f" Current profile: [cyan]{current_profile or user_config.active_profile_name}[/]") - rich.print() - rich.print("Git clone for private dreadnode repositories requires a Dreadnode SaaS server") - rich.print("(ending with '.dreadnode.io') for authentication to work properly.") - rich.print() + print_warning( + f"Current server is not a Dreadnode SaaS server\n" + f" Current server: [cyan]{current_server}[/]\n" + f" Current profile: [cyan]{current_profile or user_config.active_profile_name}[/]\n\n" + "Git clone for private dreadnode repositories requires a Dreadnode SaaS server\n" + "(ending with '.dreadnode.io') for authentication to work properly.\n" + ) # Check if there are any SaaS profiles available saas_profiles = find_dreadnode_saas_profiles(user_config) if saas_profiles: - rich.print("Available Dreadnode SaaS profiles:") + print_info("Available Dreadnode SaaS profiles:") for profile in saas_profiles: server_url = user_config.servers[profile].url - rich.print(f" - [green]{profile}[/] ({server_url})") - rich.print() + console.print(f" - [bold]{profile}[/] ({server_url})") choices = ["continue", "switch", "cancel"] choice = Prompt.ask( - "Choose an option", choices=choices, default="cancel", show_choices=True + "\nChoose an option", choices=choices, default="cancel", show_choices=True ) if choice == "continue": - rich.print( - ":warning: [yellow]Continuing with current server - private repository access may fail[/]" - ) + print_warning("Continuing with current server - private repository access may fail") return current_profile or user_config.active_profile_name if choice == "cancel": - rich.print("Cancelled.") return None if choice == "switch": # Let user pick a profile profile_choice = Prompt.ask( - "Select profile to use", choices=saas_profiles, default=saas_profiles[0] - ) - rich.print( - f":arrows_counterclockwise: Using profile '[green]{profile_choice}[/]' for this operation" + "\nSelect profile to use", + choices=saas_profiles, + default=saas_profiles[0], + console=console, ) + print_info(f"Using profile '[cyan]{profile_choice}[/]' for this operation") return profile_choice else: # No SaaS profiles available - choice = Prompt.ask("Continue anyway?", choices=["y", "n"], default="n") - - if choice == "y": - rich.print( - ":warning: [yellow]Continuing with current server - private repository access may fail[/]" + if not confirm("Continue anyway?"): + print_info( + "Cancelled. Use [bold]dreadnode login --server https://platform.dreadnode.io[/] to add a SaaS profile." ) - return current_profile or user_config.active_profile_name - rich.print( - "Cancelled. Use [bold]dreadnode login --server https://platform.dreadnode.io[/] to add a SaaS profile." - ) + print_warning("Continuing with current server - private repository access may fail") + return current_profile or user_config.active_profile_name return None diff --git a/dreadnode/cli/main.py b/dreadnode/cli/main.py index f1f98a05..b0a05d75 100644 --- a/dreadnode/cli/main.py +++ b/dreadnode/cli/main.py @@ -8,13 +8,12 @@ import webbrowser import cyclopts -import rich from rich.panel import Panel -from rich.prompt import Prompt from dreadnode.api.client import ApiClient from dreadnode.cli.agent import cli as agent_cli from dreadnode.cli.api import create_api_client +from dreadnode.cli.docker import DockerImage, docker_login, docker_pull, docker_tag, get_registry from dreadnode.cli.eval import cli as eval_cli from dreadnode.cli.github import ( GithubRepo, @@ -25,6 +24,7 @@ from dreadnode.cli.profile import cli as profile_cli from dreadnode.cli.study import cli as study_cli from dreadnode.constants import DEBUG, PLATFORM_BASE_URL +from dreadnode.logging_ import confirm, console, print_info, print_success from dreadnode.user_config import ServerConfig, UserConfig cli = cyclopts.App(help="Interact with Dreadnode platforms", version_flags=[], help_on_error=True) @@ -43,30 +43,29 @@ def meta( *tokens: t.Annotated[str, cyclopts.Parameter(show=False, allow_leading_hyphen=True)], ) -> None: try: - rich.print() + console.print() cli(tokens) except Exception as e: if DEBUG: raise - - rich.print() - rich.print(Panel(str(e), title="Error", title_align="left", border_style="red")) + console.print() + console.print(Panel(str(e), title="Error", title_align="left", border_style="red")) sys.exit(1) @cli.command(group="Auth") def login( *, - server: t.Annotated[ - str | None, - cyclopts.Parameter(name=["--server", "-s"], help="URL of the server"), - ] = None, - profile: t.Annotated[ - str | None, - cyclopts.Parameter(name=["--profile", "-p"], help="Profile alias to assign / update"), - ] = None, + server: t.Annotated[str | None, cyclopts.Parameter(name=["--server", "-s"])] = None, + profile: t.Annotated[str | None, cyclopts.Parameter(name=["--profile", "-p"])] = None, ) -> None: - """Authenticate to a Dreadnode platform server and save the profile.""" + """ + Authenticate to a Dreadnode platform server and save the profile. + + Args: + server: The server URL to authenticate against. + profile: The profile name to save the server configuration under. + """ if not server: server = PLATFORM_BASE_URL with contextlib.suppress(Exception): @@ -76,7 +75,7 @@ def login( # create client with no auth data client = ApiClient(base_url=server) - rich.print(":laptop_computer: Requesting device code ...") + print_info("Requesting device code ...") # request user and device codes codes = client.get_device_codes() @@ -85,16 +84,15 @@ def login( verification_url = client.url_for_user_code(codes.user_code) verification_url_base = verification_url.split("?")[0] - rich.print() - rich.print( - f"""\ -Attempting to automatically open the authorization page in your default browser. -If the browser does not open or you wish to use a different device, open the following URL: + print_info( + f""" + Attempting to automatically open the authorization page in your default browser. + If the browser does not open or you wish to use a different device, open the following URL: -:link: [bold]{verification_url_base}[/] + [bold]{verification_url_base}[/] -Then enter the code: [bold]{codes.user_code}[/] -""" + Then enter the code: [bold]{codes.user_code}[/] + """ ) webbrowser.open(verification_url) @@ -111,7 +109,8 @@ def login( ) user = client.get_user() - UserConfig.read().set_server_config( + user_config = UserConfig.read() + user_config.set_server_config( ServerConfig( url=server, access_token=tokens.access_token, @@ -121,9 +120,11 @@ def login( api_key=user.api_key.key, ), profile, - ).write() + ) + user_config.active = profile + user_config.write() - rich.print(f":white_check_mark: Authenticated as {user.email_address} ({user.username})") + print_success(f"Authenticated as {user.email_address} ({user.username})") @cli.command(group="Auth") @@ -142,32 +143,64 @@ def refresh() -> None: user_config.set_server_config(server_config).write() - rich.print( - f":white_check_mark: Refreshed '[bold]{user_config.active}[/bold]' ([magenta]{user.email_address}[/] / [cyan]{user.username}[/])" + print_success( + f"Refreshed '[bold]{user_config.active}[/bold]' ([magenta]{user.email_address}[/] / [cyan]{user.username}[/])" ) +@cli.command() +def pull(image: str) -> None: + """ + Pull a capability image from the dreadnode registry. + + Args: + image: The name of the image to pull (e.g. dreadnode/agent:latest). + """ + user_config = UserConfig.read() + if not user_config.active_profile_name: + raise RuntimeError("No server profile is set, use [bold]dreadnode login[/] to authenticate") + + server_config = user_config.get_server_config() + + docker_image = DockerImage(image) + tag_as: str | None = None + if docker_image.repository.startswith("dreadnode/") and not docker_image.registry: + docker_image = docker_image.with_( + registry=get_registry(user_config.get_server_config()), + ) + tag_as = image + + if docker_image.registry and docker_image.registry != "docker.io": + print_info(f"Authenticating to [bold]{docker_image.registry}[/] ...") + docker_login(docker_image.registry, server_config.username, server_config.api_key) + + print_info(f"Pulling image [bold]{docker_image}[/] ...") + docker_pull(docker_image) + + if tag_as: + docker_tag(docker_image, tag_as) + + @cli.command() def clone( - repo: t.Annotated[str, cyclopts.Parameter(help="Repository name or URL")], - target: t.Annotated[ - pathlib.Path | None, - cyclopts.Parameter(help="The target directory"), - ] = None, + repo: str, + target: pathlib.Path | None = None, ) -> None: - """Clone a GitHub repository to a local directory""" + """ + Clone a GitHub repository to a local directory + Args: + repo: Repository name or URL. + target: The target directory. + """ github_repo = GithubRepo(repo) # Check if the target directory exists target = target or pathlib.Path(github_repo.repo) if target.exists(): - if ( - Prompt.ask(f":axe: Overwrite {target.absolute()}?", choices=["y", "n"], default="n") - == "n" - ): + if not confirm(f"{target.absolute()} exists, overwrite?"): return - rich.print() + console.print() shutil.rmtree(target) # Check if the repo is accessible @@ -187,7 +220,7 @@ def clone( github_access_token = create_api_client(profile=profile_to_use).get_github_access_token( [github_repo.repo] ) - rich.print(":key: Accessed private repository") + print_info("Accessed private repository") temp_dir = download_and_unzip_archive( github_repo.api_zip_url, headers={"Authorization": f"Bearer {github_access_token.token}"}, @@ -204,8 +237,7 @@ def clone( shutil.move(temp_dir, target) - rich.print() - rich.print(f":tada: Cloned [b]{repo}[/] to [b]{target.absolute()}[/]") + print_success(f"Cloned [b]{repo}[/] to [b]{target.absolute()}[/]") @cli.command(help="Show versions and exit.", group="Meta") @@ -215,6 +247,6 @@ def version() -> None: os_name = platform.system() arch = platform.machine() - rich.print(f"Platform: {os_name} ({arch})") - rich.print(f"Python: {python_version}") - rich.print(f"Dreadnode: {version}") + print_info(f"Platform: {os_name} ({arch})") + print_info(f"Python: {python_version}") + print_info(f"Dreadnode: {version}") diff --git a/dreadnode/cli/platform/cli.py b/dreadnode/cli/platform/cli.py index 89879b2f..2e261d7c 100644 --- a/dreadnode/cli/platform/cli.py +++ b/dreadnode/cli/platform/cli.py @@ -2,100 +2,199 @@ import cyclopts -from dreadnode.cli.platform.configure import configure_platform, list_configurations +from dreadnode.cli.docker import ( + DockerError, + get_env_var_from_container, +) +from dreadnode.cli.platform.compose import ( + build_compose_override_file, + compose_down, + compose_login, + compose_logs, + compose_up, + platform_is_running, +) +from dreadnode.cli.platform.constants import PLATFORM_SERVICES from dreadnode.cli.platform.download import download_platform -from dreadnode.cli.platform.login import log_into_registries -from dreadnode.cli.platform.start import start_platform -from dreadnode.cli.platform.status import platform_status -from dreadnode.cli.platform.stop import stop_platform -from dreadnode.cli.platform.upgrade import upgrade_platform -from dreadnode.cli.platform.utils.printing import print_info -from dreadnode.cli.platform.utils.versions import get_current_version +from dreadnode.cli.platform.env_mgmt import ( + build_env_file, + read_env_file, + remove_overrides_env, + write_overrides_env, +) +from dreadnode.cli.platform.tag import tag_to_semver +from dreadnode.cli.platform.version import VersionConfig +from dreadnode.logging_ import confirm, print_error, print_info, print_success, print_warning cli = cyclopts.App("platform", help="Run and manage the platform.", help_flags=[]) @cli.command() -def start( - tag: t.Annotated[ - str | None, cyclopts.Parameter(help="Optional image tag to use when starting the platform.") - ] = None, - **env_overrides: t.Annotated[ - str, - cyclopts.Parameter( - help="Environment variable overrides. Use --key value format. " - "Examples: --proxy-host myproxy.local" - ), - ], -) -> None: - """Start the platform. Optionally, provide a tagged version to start. +def start(tag: str | None = None, **env_overrides: str) -> None: + """ + Start the platform. Args: - tag: Optional image tag to use when starting the platform. - **env_overrides: Key-value pairs to override environment variables in the + tag: Image tag to use when starting the platform. + env_overrides: Key-value pairs to override environment variables in the platform's .env file. e.g `--proxy-host myproxy.local` """ - start_platform(tag=tag, **env_overrides) + version_config = VersionConfig.read() + version = version_config.get_current_version(tag=tag) or download_platform(tag) + version_config.set_current_version(version) + + if platform_is_running(version): + print_info(f"Platform {version.tag} is already running.") + print_info("Use `dreadnode platform stop` to stop it first.") + return + + compose_login(version) + + if env_overrides: + write_overrides_env(version.arg_overrides_env_file, **env_overrides) + + print_info(f"Starting platform [cyan]{version.tag}[/] ...") + try: + compose_up(version) + print_success("Platform started.") + origin = get_env_var_from_container("dreadnode-ui", "ORIGIN") + if origin: + print_info("You can access the app at the following URLs:") + print_info(f" - {origin}") + else: + print_info(" - Unable to determine the app URL.") + print_info("Please check the container logs for more information.") + except DockerError as e: + compose_logs(version, tail=10) + print_error(str(e)) @cli.command(name=["stop", "down"]) -def stop() -> None: - """Stop the running platform.""" - stop_platform() +def stop(*, remove_volumes: t.Annotated[bool, cyclopts.Parameter(negative=False)] = False) -> None: + """ + Stop the running platform. + + Args: + remove_volumes: Also remove Docker volumes associated with the platform. + """ + version = VersionConfig.read().get_current_version() + if not version: + print_error("No current version found. Nothing to stop.") + return + + remove_overrides_env(version.arg_overrides_env_file) + compose_down(version, remove_volumes=remove_volumes) + print_success("Platform stopped.") @cli.command() -def download( - tag: t.Annotated[ - str | None, cyclopts.Parameter(help="Optional image tag to use when starting the platform.") - ] = None, -) -> None: - """Download platform files for a specific tag. +def logs(tail: int = 100) -> None: + """ + View the platform logs. Args: - tag: Optional image tag to download. + tail: Number of lines to show from the end of the logs for each service. + """ + version = VersionConfig.read().get_current_version() + if not version: + print_error("No current version found. Nothing to show logs for.") + return + + compose_logs(version, tail=tail) + + +@cli.command() +def download(tag: str | None = None) -> None: + """ + Download platform files for a specific tag. + + Args: + tag: Specific version tag to download. """ download_platform(tag=tag) @cli.command() def upgrade() -> None: - """Upgrade the platform to the latest version.""" - upgrade_platform() + """ + Upgrade the platform to the latest available version. + + Downloads the latest version, compares it with the current version, + and performs the upgrade if a newer version is available. Optionally + merges configuration files from the current version to the new version. + Stops the current platform and starts the upgraded version. + """ + version_config = VersionConfig.read() + current_version = version_config.get_current_version() + if not current_version: + start() + return + + latest_version = download_platform() + + current_semver = tag_to_semver(current_version.tag) + remote_semver = tag_to_semver(latest_version.tag) + + if current_semver >= remote_semver: + print_info(f"You are using the latest ({current_semver}) version of the platform.") + return + + if not confirm( + f"Upgrade from [cyan]{current_version.tag}[/] -> [magenta]{latest_version.tag}[/]?" + ): + return + + version_config.set_current_version(latest_version) + + # copy the configuration overrides from the current version to the new version + if ( + current_version.configure_overrides_compose_file.exists() + and current_version.configure_overrides_env_file.exists() + ): + latest_version.configure_overrides_compose_file.write_text( + current_version.configure_overrides_compose_file.read_text() + ) + latest_version.configure_overrides_env_file.write_text( + current_version.configure_overrides_env_file.read_text() + ) + + print_info("Stopping current platform ...") + compose_down(current_version) + compose_up(latest_version) + print_success(f"Platform upgraded to version [magenta]{latest_version.tag}[/].") @cli.command() def refresh_registry_auth() -> None: - """Refresh container registry credentials for platform access. + """ + Refresh container registry credentials for platform access. Used for out of band Docker management. """ - log_into_registries() + current_version = VersionConfig.read().get_current_version() + if not current_version: + print_info("There are no registries configured. Run `dreadnode platform start` to start.") + return + + compose_login(current_version, force=True) @cli.command() def configure( - *args: t.Annotated[ - str, - cyclopts.Parameter( - help="Key-value pairs to set. Must be provided in pairs (key value key value ...). ", - ), - ], - tag: t.Annotated[ - str | None, cyclopts.Parameter(help="Optional image tag to use when starting the platform.") - ] = None, + *args: str, + tag: str | None = None, list: t.Annotated[ bool, - cyclopts.Parameter( - ["--list", "-l"], help="List current configuration without making changes." - ), + cyclopts.Parameter(["--list", "-l"], negative=False), ] = False, unset: t.Annotated[ bool, - cyclopts.Parameter(["--unset", "-u"], help="Remove the specified configuration."), + cyclopts.Parameter(["--unset", "-u"], negative=False), ] = False, ) -> None: - """Configure the platform for a specific service. + """ + Configure the platform for a specific service. + Configurations will take effect the next time the platform is started and are persisted. Usage: platform configure KEY VALUE [KEY2 VALUE2 ...] @@ -104,14 +203,28 @@ def configure( platform configure proxy-host myproxy.local api-port 8080 Args: - *args: Key-value pairs to set. Must be provided in pairs (key value key value ...). + args: Key-value pairs to set. Must be provided in pairs (key value key value ...). tag: Optional image tag to use when starting the platform. + list: List current configuration without making changes. + unset: Remove the specified configuration. """ + current_version = VersionConfig.read().get_current_version(tag=tag) + if not current_version: + print_info("No current platform version is set. Please start or download the platform.") + return + if list: - if args: - raise ValueError("The --list option does not take any positional arguments.") - list_configurations() + overrides_env_file = current_version.configure_overrides_env_file + if not overrides_env_file.exists(): + print_info("No configuration overrides found.") + return + + print_info(f"Configuration overrides from {overrides_env_file}:") + env_vars = read_env_file(overrides_env_file) + for key, value in env_vars.items(): + print_info(f" - {key}={value}") return + # Parse positional arguments into key-value pairs if not unset and len(args) % 2 != 0: raise ValueError( @@ -122,43 +235,57 @@ def configure( env_overrides = {} for i in range(0, len(args), 2): key = args[i] - value = args[i + 1] if not unset else None - env_overrides[key] = value + env_overrides[key] = args[i + 1] if not unset else None - configure_platform(tag=tag, **env_overrides) + if not env_overrides: + print_warning("No configuration changes specified.") + return + + print_info("Setting environment overrides ...") + build_compose_override_file(PLATFORM_SERVICES, current_version) + build_env_file(current_version.configure_overrides_env_file, **env_overrides) + print_info( + f"Configuration written to {current_version.local_path}.\n\n" + "These will take effect the next time the platform is started. " + "You can modify or remove them at any time." + ) @cli.command() def version( verbose: t.Annotated[ # noqa: FBT002 - bool, - cyclopts.Parameter( - ["--verbose", "-v"], help="Display detailed information for the version." - ), + bool, cyclopts.Parameter(["--verbose", "-v"]) ] = False, ) -> None: - """Show the current platform version.""" - version = get_current_version() - if version: - if verbose: - print_info(version.details) - else: - print_info(f"Current platform version: {version!s}") + """ + Show the current platform version. - else: + Args: + verbose: Display detailed information about the version. + """ + version_config = VersionConfig.read() + version = version_config.get_current_version() + if version is None: print_info("No current platform version is set.") + return + print_info(f"Current version: [cyan]{version}[/]") + if verbose: + print_info(version.details) -@cli.command() -def status( - tag: t.Annotated[ - str | None, cyclopts.Parameter(help="Optional image tag to use when checking status.") - ] = None, -) -> None: - """Get the status of the platform with the specified or current version. - Args: - tag: Optional image tag to use. If not provided, uses the current - version or downloads the latest available version. +@cli.command() +def status() -> None: """ - platform_status(tag=tag) + Get the current status of the platform. + """ + version_config = VersionConfig.read() + version = version_config.get_current_version() + if version is None: + print_error("No current platform version is set. Please start or download the platform.") + return + + if platform_is_running(version): + print_success(f"Platform {version.tag} is running.") + else: + print_error(f"Platform {version.tag} is not fully running.") diff --git a/dreadnode/cli/platform/compose.py b/dreadnode/cli/platform/compose.py new file mode 100644 index 00000000..37813822 --- /dev/null +++ b/dreadnode/cli/platform/compose.py @@ -0,0 +1,212 @@ +import typing as t + +import yaml +from yaml import safe_dump + +from dreadnode.cli.api import create_api_client +from dreadnode.cli.docker import ( + DockerImage, + docker_compose_ps, + docker_login, + docker_run, + get_available_local_images, +) +from dreadnode.cli.platform.constants import PlatformService +from dreadnode.cli.platform.env_mgmt import read_env_file +from dreadnode.cli.platform.version import LocalVersion +from dreadnode.logging_ import print_info + + +def build_compose_override_file( + services: list[PlatformService], + version: LocalVersion, +) -> None: + # build a yaml docker compose override file + # that only includes the service being configured + # and has an `env_file` attribute for the service + override = { + "services": { + f"{service}": {"env_file": [version.configure_overrides_env_file.as_posix()]} + for service in services + }, + } + + with version.configure_overrides_compose_file.open("w") as f: + safe_dump(override, f, sort_keys=False) + + +def get_required_images(version: LocalVersion) -> list[DockerImage]: + """ + Get the list of required Docker images for the specified platform version. + + Args: + version: The selected version of the platform. + + Returns: + list[str]: List of required Docker image names. + """ + result = docker_run( + ["compose", *get_compose_args(version), "config", "--images"], + timeout=120, + capture_output=True, + ) + + if result.returncode != 0: + return [] + + return [DockerImage(line) for line in result.stdout.splitlines() if line.strip()] + + +def get_required_services(version: LocalVersion) -> list[str]: + """Get the list of required services from the docker-compose file. + + Returns: + list[str]: List of required service names. + """ + contents: dict[str, object] = yaml.safe_load(version.compose_file.read_text()) + services = t.cast("dict[str, object]", contents.get("services", {}) or {}) + return [ + name + for name, cfg in services.items() + if isinstance(cfg, dict) and "x-required" in cfg and cfg["x-required"] is True + ] + + +def get_profiles_to_enable(version: LocalVersion) -> list[str]: + """ + Get the list of profiles to enable based on environment variables. + + If any of the `x-profile-disabled-vars` are set in the environment, + the profile will be disabled. + + E.g. + + services: + myservice: + image: myimage:latest + profiles: + - myprofile + x-profile-override-vars: + - MY_SERVICE_HOST + + If MY_SERVICE_HOST is set in the environment, the `myprofile` profile + will NOT be excluded from the docker compose --profile cmd. + """ + + contents: dict[str, object] = yaml.safe_load(version.compose_file.read_text()) + services = t.cast("dict[str, object]", contents.get("services", {}) or {}) + profiles_to_enable: set[str] = set() + for service in services.values(): + if not isinstance(service, dict): + continue + + profiles = service.get("profiles", []) + if not profiles or not isinstance(profiles, list): + continue + + x_override_vars = service.get("x-profile-override-vars", []) + if not x_override_vars or not isinstance(x_override_vars, list): + profiles_to_enable.update(profiles) + continue + + configuration_file = version.configure_overrides_env_file + overrides_file = version.arg_overrides_env_file + + env_vars: dict[str, str] = {} + if configuration_file.exists(): + env_vars.update(read_env_file(configuration_file)) + if overrides_file.exists(): + env_vars.update(read_env_file(overrides_file)) + + # check if any of the override vars are set in the env + if any(var in env_vars for var in x_override_vars): + continue # skip enabling this profile + + profiles_to_enable.update(profiles) + + return list(profiles_to_enable) + + +def get_compose_args( + version: LocalVersion, *, project_name: str = "dreadnode-platform" +) -> list[str]: + command = ["-p", project_name] + compose_files = [version.compose_file] + env_files = [version.api_env_file, version.ui_env_file] + + if ( + version.configure_overrides_compose_file.exists() + and version.configure_overrides_env_file.exists() + ): + compose_files.append(version.configure_overrides_compose_file) + env_files.append(version.configure_overrides_env_file) + + for compose_file in compose_files: + command.extend(["-f", compose_file.as_posix()]) + + for profile in get_profiles_to_enable(version): + command.extend(["--profile", profile]) + + if version.arg_overrides_env_file.exists(): + env_files.append(version.arg_overrides_env_file) + + for env_file in env_files: + command.extend(["--env-file", env_file.as_posix()]) + + return command + + +def platform_is_running(version: LocalVersion) -> bool: + """ + Check if the platform with the specified or current version is running. + + Args: + version: LocalVersionSchema of the platform to check. + """ + containers = docker_compose_ps(get_compose_args(version)) + if not containers: + return False + + for service in get_required_services(version): + if service not in [c.name for c in containers if c.status == "running"]: + return False + + return True + + +def compose_up(version: LocalVersion) -> None: + docker_run( + ["compose", *get_compose_args(version), "up", "-d"], + ) + + +def compose_down(version: LocalVersion, *, remove_volumes: bool = False) -> None: + args = ["compose", *get_compose_args(version), "down"] + if remove_volumes: + args.append("--volumes") + docker_run(args) + + +def compose_logs(version: LocalVersion, *, tail: int = 100) -> None: + docker_run(["compose", *get_compose_args(version), "logs", "--tail", str(tail)]) + + +def compose_login(version: LocalVersion, *, force: bool = False) -> None: + # check to see if all required images are available locally + required_images = get_required_images(version) + available_images = get_available_local_images() + missing_images = [img for img in required_images if img not in available_images] + if not missing_images and not force: + return + + client = create_api_client() + registry_credentials = client.get_platform_registry_credentials() + + registries_attempted = set() + for image in version.images: + if image.registry not in registries_attempted: + print_info(f"Logging in to Docker registry: {image.registry} ...") + docker_login( + image.registry, registry_credentials.username, registry_credentials.password + ) + registries_attempted.add(image.registry) diff --git a/dreadnode/cli/platform/configure.py b/dreadnode/cli/platform/configure.py deleted file mode 100644 index 171c9666..00000000 --- a/dreadnode/cli/platform/configure.py +++ /dev/null @@ -1,47 +0,0 @@ -from dreadnode.cli.platform.constants import SERVICES -from dreadnode.cli.platform.docker_ import build_docker_compose_override_file -from dreadnode.cli.platform.utils.env_mgmt import build_env_file, read_env_file -from dreadnode.cli.platform.utils.printing import print_info -from dreadnode.cli.platform.utils.versions import get_current_version, get_local_version - - -def list_configurations() -> None: - """List the current platform configuration overrides, if any.""" - current_version = get_current_version() - if not current_version: - print_info("No current platform version is set. Please start or download the platform.") - return - - overrides_env_file = current_version.configure_overrides_env_file - if not overrides_env_file.exists(): - print_info("No configuration overrides found.") - return - - print_info(f"Configuration overrides from {overrides_env_file}:") - env_vars = read_env_file(overrides_env_file) - for key, value in env_vars.items(): - print_info(f" - {key}={value}") - - -def configure_platform(tag: str | None = None, **env_overrides: str | None) -> None: - """Configure the platform for a specific service. - - Args: - service: The name of the service to configure. - """ - selected_version = get_local_version(tag) if tag else get_current_version() - # No need to mark current version on configure - - if not selected_version: - print_info("No current platform version is set. Please start or download the platform.") - return - - if env_overrides: - print_info("Setting environment overrides...") - build_docker_compose_override_file(SERVICES, selected_version) - build_env_file(selected_version.configure_overrides_env_file, **env_overrides) - print_info( - f"Configuration written to {selected_version.local_path}. " - "These will take effect the next time the platform is started." - " You can modify or remove them at any time." - ) diff --git a/dreadnode/cli/platform/constants.py b/dreadnode/cli/platform/constants.py index 0b332953..5fe054d9 100644 --- a/dreadnode/cli/platform/constants.py +++ b/dreadnode/cli/platform/constants.py @@ -1,12 +1,16 @@ import typing as t +from dreadnode.constants import DEFAULT_LOCAL_STORAGE_DIR + PlatformService = t.Literal["dreadnode-api", "dreadnode-ui"] +PLATFORM_SERVICES = t.cast("list[PlatformService]", t.get_args(PlatformService)) API_SERVICE: PlatformService = "dreadnode-api" UI_SERVICE: PlatformService = "dreadnode-ui" -SERVICES: list[PlatformService] = [API_SERVICE, UI_SERVICE] -VERSIONS_MANIFEST = "versions.json" SupportedArchitecture = t.Literal["amd64", "arm64"] -SUPPORTED_ARCHITECTURES: list[SupportedArchitecture] = ["amd64", "arm64"] +SUPPORTED_ARCHITECTURES = t.cast("list[SupportedArchitecture]", t.get_args(SupportedArchitecture)) DEFAULT_DOCKER_PROJECT_NAME = "dreadnode-platform" + +PLATFORM_STORAGE_DIR = DEFAULT_LOCAL_STORAGE_DIR / "platform" +VERSION_CONFIG_PATH = PLATFORM_STORAGE_DIR / "versions.json" diff --git a/dreadnode/cli/platform/docker_.py b/dreadnode/cli/platform/docker_.py deleted file mode 100644 index c44587d6..00000000 --- a/dreadnode/cli/platform/docker_.py +++ /dev/null @@ -1,546 +0,0 @@ -import json -import subprocess -import typing as t -from dataclasses import dataclass -from enum import Enum - -import yaml -from pydantic import BaseModel, Field -from yaml import safe_dump - -from dreadnode.cli.api import create_api_client -from dreadnode.cli.platform.constants import DEFAULT_DOCKER_PROJECT_NAME, PlatformService -from dreadnode.cli.platform.schemas import LocalVersionSchema -from dreadnode.cli.platform.utils.env_mgmt import read_env_file -from dreadnode.cli.platform.utils.printing import print_error, print_info, print_success - -DockerContainerState = t.Literal[ - "running", "exited", "paused", "restarting", "removing", "created", "dead" -] - - -# create a DockerError exception that I can catch -class DockerError(Exception): - pass - - -class CaptureOutput(str, Enum): - TRUE = "true" - FALSE = "false" - - -@dataclass -class DockerImage: - repository: str - tag: str | None = None - digest: str | None = None - - @classmethod - def from_string(cls, image_string: str) -> "DockerImage": - """ - Parse a Docker image string into repository, tag, and SHA components. - - Examples: - - postgres:16 -> repository="postgres", tag="16", sha=None - - minio/minio:latest -> repository="minio/minio", tag="latest", sha=None - - image@sha256:abc123 -> repository="image", tag=None, sha="sha256:abc123" - """ - # Check if there's a SHA digest (contains @) - if "@" in image_string: - repo_part, sha = image_string.split("@", 1) - # Check if there's also a tag before the @ - if ":" in repo_part: - repository, tag = repo_part.rsplit(":", 1) - return cls(repository=repository, tag=tag, digest=sha) - return cls(repository=repo_part, tag=None, digest=sha) - - # Check if there's a tag (contains :) - if ":" in image_string: - # Use rsplit to handle cases like registry.com:5000/image:tag - repository, tag = image_string.rsplit(":", 1) - return cls(repository=repository, tag=tag, digest=None) - - # Just repository name - return cls(repository=image_string, tag=None, digest=None) - - def __str__(self) -> str: - """Reconstruct the original image string format.""" - result = self.repository - if self.tag: - result += f":{self.tag}" - if self.digest: - result += f"@{self.digest}" - return result - - def __eq__(self, other: object) -> bool: - """Check if two DockerImage instances are equal. - - If they both have digests, compare digests. - If they both have tags, compare tags. - - """ - if not isinstance(other, DockerImage): - return False - if self.repository != other.repository: - return False - if self.digest and other.digest: - return self.digest == other.digest - if self.tag and other.tag: - return self.tag == other.tag - return False - - def __ne__(self, other: object) -> bool: - """Check if two DockerImage instances are not equal.""" - return not self.__eq__(other) - - def __hash__(self) -> int: - """Generate a hash for the DockerImage instance.""" - if self.tag: - return hash((self.repository, self.tag)) - if self.digest: - return hash((self.repository, self.digest)) - return hash((self.repository,)) - - -class DockerPSResult(BaseModel): - name: str = Field(..., alias="Name") - exit_code: int = Field(..., alias="ExitCode") - state: DockerContainerState = Field(..., alias="State") - status: str = Field(..., alias="Status") - - @property - def is_running(self) -> bool: - return self.state == "running" - - -def _build_docker_compose_base_command( - selected_version: LocalVersionSchema, -) -> list[str]: - cmds = [] - compose_files = [selected_version.compose_file] - env_files = [ - selected_version.api_env_file, - selected_version.ui_env_file, - ] - - if ( - selected_version.configure_overrides_compose_file.exists() - and selected_version.configure_overrides_env_file.exists() - ): - compose_files.append(selected_version.configure_overrides_compose_file) - env_files.append(selected_version.configure_overrides_env_file) - - for compose_file in compose_files: - cmds.extend(["-f", compose_file.as_posix()]) - - for profile in _get_profiles_to_enable(selected_version): - cmds.extend(["--profile", profile]) - - if selected_version.arg_overrides_env_file.exists(): - env_files.append(selected_version.arg_overrides_env_file) - - for env_file in env_files: - cmds.extend(["--env-file", env_file.as_posix()]) - return cmds - - -def _check_docker_installed() -> bool: - """Check if Docker is installed on the system.""" - try: - cmd = ["docker", "--version"] - subprocess.run( # noqa: S603 - cmd, - check=True, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - ) - - except subprocess.CalledProcessError: - print_error("Docker is not installed. Please install Docker and try again.") - return False - - return True - - -def _check_docker_compose_installed() -> bool: - """Check if Docker Compose is installed on the system.""" - try: - cmd = ["docker", "compose", "--version"] - subprocess.run( # noqa: S603 - cmd, - check=True, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - ) - except subprocess.CalledProcessError: - print_error("Docker Compose is not installed. Please install Docker Compose and try again.") - return False - return True - - -def get_required_service_names(selected_version: LocalVersionSchema) -> list[str]: - """Get the list of require service names from the docker-compose file.""" - contents: dict[str, t.Any] = yaml.safe_load(selected_version.compose_file.read_text()) - services = contents.get("services", {}) or {} - return [name for name, cfg in services.items() if isinstance(cfg, dict) and "x-required" in cfg] - - -def _get_profiles_to_enable(selected_version: LocalVersionSchema) -> list[str]: - """Get the list of profiles to enable based on environment variables. - - If any of the `x-profile-disabled-vars` are set in the environment, - the profile will be disabled. - - E.g. - - services: - myservice: - image: myimage:latest - profiles: - - myprofile - x-profile-override-vars: - - MY_SERVICE_HOST - - If MY_SERVICE_HOST is set in the environment, the `myprofile` profile - will NOT be excluded from the docker compose --profile cmd. - - Args: - selected_version: The selected version of the platform. - - Returns: - List of profile names to enable. - """ - - contents: dict[str, t.Any] = yaml.safe_load(selected_version.compose_file.read_text()) - services = contents.get("services", {}) or {} - profiles_to_enable: set[str] = set() - for service in services.values(): - if not isinstance(service, dict): - continue - profiles = service.get("profiles", []) - if not profiles or not isinstance(profiles, list): - continue - x_override_vars = service.get("x-profile-override-vars", []) - if not x_override_vars or not isinstance(x_override_vars, list): - profiles_to_enable.update(profiles) - continue - - configuration_file = selected_version.configure_overrides_env_file - overrides_file = selected_version.arg_overrides_env_file - env_vars = {} - if configuration_file.exists(): - env_vars.update(read_env_file(configuration_file)) - if overrides_file.exists(): - env_vars.update(read_env_file(overrides_file)) - # check if any of the override vars are set in the env - if any(var in env_vars for var in x_override_vars): - continue # skip enabling this profile - profiles_to_enable.update(profiles) - - return list(profiles_to_enable) - - -def _run_docker_compose_command( - args: list[str], - timeout: int = 300, - stdin_input: str | None = None, - capture_output: CaptureOutput | None = None, -) -> subprocess.CompletedProcess[str]: - """Execute a docker compose command with common error handling and configuration. - - Args: - args: Additional arguments for the docker compose command. - compose_file: Path to docker-compose file. - timeout: Command timeout in seconds. - command_name: Name of the command for error messages. - stdin_input: Input to pass to stdin (for commands like docker login). - - Returns: - CompletedProcess object with command results. - - Raises: - subprocess.CalledProcessError: If command fails. - subprocess.TimeoutExpired: If command times out. - FileNotFoundError: If docker/docker-compose not found. - """ - cmd = ["docker", "compose"] - - cmd.extend(["-p", DEFAULT_DOCKER_PROJECT_NAME]) - - # Add the specific command arguments - cmd.extend(args) - - cmd_str = " ".join(cmd) - - try: - # Remove capture_output=True to allow real-time streaming - # stdout and stderr will go directly to the terminal - result = subprocess.run( # noqa: S603 - cmd, - check=True, - text=True, - timeout=timeout, - encoding="utf-8", - errors="replace", - input=stdin_input, - capture_output=bool(capture_output == CaptureOutput.TRUE), - ) - - except subprocess.CalledProcessError as e: - print_error(f"{cmd_str} failed with exit code {e.returncode}") - raise DockerError(f"Docker command failed: {e}") from e - - except subprocess.TimeoutExpired as e: - print_error(f"{cmd_str} timed out after {timeout} seconds") - raise DockerError(f"Docker command timed out after {timeout} seconds") from e - - except FileNotFoundError as e: - print_error("Docker or docker compose not found. Please ensure Docker is installed.") - raise DockerError(f"Docker compose file not found: {e}") from e - - return result - - -def build_docker_compose_override_file( - services: list[PlatformService], - selected_version: LocalVersionSchema, -) -> None: - # build a yaml docker compose override file - # that only includes the service being configured - # and has an `env_file` attribute for the service - override = { - "services": { - f"{service}": {"env_file": [selected_version.configure_overrides_env_file.as_posix()]} - for service in services - }, - } - - with selected_version.configure_overrides_compose_file.open("w") as f: - safe_dump(override, f, sort_keys=False) - - -def get_available_local_images() -> list[DockerImage]: - """Get the list of available Docker images on the local system. - - Returns: - list[str]: List of available Docker image names. - """ - cmd = ["docker", "images", "--format", "{{.Repository}}:{{.Tag}}@{{.Digest}}"] - cp = subprocess.run( # noqa: S603 - cmd, - check=True, - text=True, - capture_output=True, - ) - images: list[DockerImage] = [] - for line in cp.stdout.splitlines(): - if line.strip(): - img = DockerImage.from_string(line.strip()) - images.append(img) - return images - - -def get_env_var_from_container(container_name: str, var_name: str) -> str | None: - """ - Get the specified environment variable from the container and return - its value. - - Args: - container_name: Name of the container to inspect. - var_name: Name of the environment variable to retrieve. - - Returns: - str | None: Value of the environment variable, or None if not found. - """ - try: - cmd = [ - "docker", - "inspect", - "-f", - "{{range .Config.Env}}{{println .}}{{end}}", - container_name, - ] - cp = subprocess.run( # noqa: S603 - cmd, - check=True, - text=True, - capture_output=True, - ) - - for line in cp.stdout.splitlines(): - if line.startswith(f"{var_name.upper()}="): - return line.split("=", 1)[1] - - except subprocess.CalledProcessError: - return None - - return None - - -def get_required_images(selected_version: LocalVersionSchema) -> list[DockerImage]: - """Get the list of required Docker images for the specified platform version. - - Args: - selected_version: The selected version of the platform. - - Returns: - list[str]: List of required Docker image names. - """ - base = _build_docker_compose_base_command(selected_version) - args = [*base, "config", "--images"] - result = _run_docker_compose_command( - args, - timeout=120, - capture_output=CaptureOutput.TRUE, - ) - - if result.returncode != 0: - return [] - - required_images: list[DockerImage] = [] - for line in result.stdout.splitlines(): - if not line.strip(): - continue - # Validate each line is a valid Docker image string - DockerImage.from_string(line.strip()) - required_images.append(DockerImage.from_string(line.strip())) - - return required_images - - -def docker_requirements_met() -> bool: - """Check if Docker and Docker Compose are installed.""" - return _check_docker_installed() and _check_docker_compose_installed() - - -def docker_login(registry: str) -> None: - """Log into a Docker registry using API credentials. - - Args: - registry: Registry hostname to log into. - - Raises: - subprocess.CalledProcessError: If docker login command fails. - """ - - print_info(f"Logging in to Docker registry: {registry} ...") - client = create_api_client() - container_registry_creds = client.get_container_registry_credentials() - - cmd = ["docker", "login", container_registry_creds.registry] - cmd.extend(["--username", container_registry_creds.username]) - cmd.extend(["--password-stdin"]) - - try: - subprocess.run( # noqa: S603 - cmd, - input=container_registry_creds.password, - text=True, - check=True, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - ) - print_success("Logged in to container registry ...") - except subprocess.CalledProcessError as e: - print_error(f"Failed to log in to container registry: {e}") - raise - - -def docker_ps( - selected_version: LocalVersionSchema, - timeout: int = 120, -) -> list[DockerPSResult]: - """Get container status for the compose project as JSON. - - This mirrors: - docker compose -f <...> -f <...> --env-file <...> --env-file <...> ps --format json [SERVICE...] - - Args: - selected_version: Version object providing compose/env files. - services: Optional list of PlatformService to filter (translated to 'platform-'). - timeout: Command timeout in seconds. - - Returns: - A list of dicts parsed from `docker compose ps --format json`. - - Raises: - ValueError: If the returned output is not valid JSON. - subprocess.CalledProcessError / TimeoutExpired / FileNotFoundError: On execution errors. - """ - base = _build_docker_compose_base_command(selected_version) - args = [*base, "ps", "--format", "json"] - - result = _run_docker_compose_command( - args, - timeout=timeout, - capture_output=CaptureOutput.TRUE, - ) - - try: - # docker compose ps --format json returns a JSON array - if not result.stdout: - return [] - stdout = str(result.stdout) - stdout_lines = stdout.splitlines() - container_info_models: list[DockerPSResult] = [] - for line in stdout_lines: - if not line.strip(): - continue - j = json.loads(line) - dpr = DockerPSResult(**j) - container_info_models.append(dpr) - except json.JSONDecodeError as e: - print_error(f"Failed read status from the Dreadnode Platform': {e}") - raise ValueError("Unexpected non-JSON output from 'docker compose ps'") from e - - return container_info_models - - -def docker_run( - selected_version: LocalVersionSchema, - timeout: int = 300, -) -> subprocess.CompletedProcess[str]: - """Run docker containers for the platform. - - Args: - compose_file: Path to docker-compose file. - timeout: Command timeout in seconds. - - Returns: - CompletedProcess object with command results. - - Raises: - subprocess.CalledProcessError: If command fails. - subprocess.TimeoutExpired: If command times out. - """ - cmds = _build_docker_compose_base_command(selected_version) - - # Apply the compose and env override files in priority order - # 1. base compose file and env files - # 2. configure overrides compose and env files (if any) - # 3. arg overrides env file (if any) - - cmds += ["up", "-d"] - return _run_docker_compose_command(cmds, timeout=timeout) - - -def docker_stop( - selected_version: LocalVersionSchema, - timeout: int = 300, -) -> subprocess.CompletedProcess[str]: - """Stop docker containers for the platform. - - Args: - selected_version: The selected version of the platform. - timeout: Command timeout in seconds. - - Returns: - CompletedProcess object with command results. - - Raises: - subprocess.CalledProcessError: If command fails. - subprocess.TimeoutExpired: If command times out. - """ - cmds = _build_docker_compose_base_command(selected_version) - cmds.append("down") - return _run_docker_compose_command(cmds, timeout=timeout) diff --git a/dreadnode/cli/platform/download.py b/dreadnode/cli/platform/download.py index 18f50943..1bb5ea14 100644 --- a/dreadnode/cli/platform/download.py +++ b/dreadnode/cli/platform/download.py @@ -1,144 +1,95 @@ import io -import json import zipfile -from dreadnode.api.models import RegistryImageDetails from dreadnode.cli.api import create_api_client -from dreadnode.cli.platform.constants import SERVICES, VERSIONS_MANIFEST -from dreadnode.cli.platform.schemas import LocalVersionSchema -from dreadnode.cli.platform.utils.env_mgmt import ( +from dreadnode.cli.docker import docker_run +from dreadnode.cli.platform.compose import compose_login, get_compose_args +from dreadnode.cli.platform.constants import PLATFORM_SERVICES, PLATFORM_STORAGE_DIR +from dreadnode.cli.platform.env_mgmt import ( create_default_env_files, ) -from dreadnode.cli.platform.utils.printing import ( - print_error, +from dreadnode.cli.platform.tag import add_tag_arch_suffix +from dreadnode.cli.platform.version import ( + LocalVersion, + VersionConfig, +) +from dreadnode.logging_ import ( + confirm, print_info, print_success, - print_warning, -) -from dreadnode.cli.platform.utils.versions import ( - confirm_with_context, - create_local_latest_tag, - get_available_local_versions, - get_cli_version, - get_local_cache_dir, ) -def _resolve_latest(tag: str) -> str: - """Resolve 'latest' tag to actual version tag from API. - - Args: - tag: Version tag that contains 'latest'. - - Returns: - str: Resolved actual version tag. +def download_platform(tag: str | None = None) -> LocalVersion: """ - api_client = create_api_client() - release_info = api_client.get_platform_releases( - tag, services=[str(service) for service in SERVICES], cli_version=get_cli_version() - ) - return release_info.tag - - -def _create_local_version_file_structure( - tag: str, release_info: RegistryImageDetails -) -> LocalVersionSchema: - """Create local file structure and update manifest for a new version. + Download platform version if not already available locally. Args: - tag: Version tag to create structure for. - release_info: Registry image details from API. + tag: Version tag to download (supports 'latest'). Returns: - LocalVersionSchema: Created local version schema. + LocalVersionSchema: Local version schema for the downloaded/existing version. """ - available_local_versions = get_available_local_versions() - - # Create a new local version schema - local_cache_dir = get_local_cache_dir() - new_version = LocalVersionSchema( - **release_info.model_dump(), - local_path=local_cache_dir / tag, - current=False, - ) - - # Add the new version to the available local versions - available_local_versions.versions.append(new_version) - - # sort the manifest by semver, newest first - available_local_versions.versions.sort(key=lambda v: v.tag, reverse=True) - - # update the manifest file - manifest_path = local_cache_dir / VERSIONS_MANIFEST - with manifest_path.open(encoding="utf-8", mode="w") as f: - json.dump(available_local_versions.model_dump(), f, indent=2) + version_config = VersionConfig.read() + api_client = create_api_client() - print_success(f"Updated versions manifest at {manifest_path} with {new_version.tag}") + # 1 - Resolve the tag - if new_version.local_path.exists(): - print_warning(f"Version {tag} already exists locally.") - if not confirm_with_context("overwrite it?"): - print_error("Aborting download.") - return new_version + tag = tag or "latest" + tag = add_tag_arch_suffix(tag) - # create the directory - new_version.local_path.mkdir(parents=True, exist_ok=True) + if "latest" in tag: + tag = api_client.get_platform_releases( + tag, services=[str(service) for service in PLATFORM_SERVICES] + ).tag - return new_version + # 2 - Check if the version is already available locally + if version_config.versions: + for available_local_version in version_config.versions: + if tag == available_local_version.tag: + print_success(f"[cyan]{tag}[/] is already downloaded.") + return available_local_version -def _download_version_files(tag: str) -> LocalVersionSchema: - """Download platform version files from API and extract locally. + # 3 - Download and check release info - Args: - tag: Version tag to download. + print_info(f"Downloading [cyan]{tag}[/] ...") - Returns: - LocalVersionSchema: Downloaded local version schema. - """ api_client = create_api_client() release_info = api_client.get_platform_releases( - tag, services=[str(service) for service in SERVICES], cli_version=get_cli_version() + tag, services=[str(service) for service in PLATFORM_SERVICES] ) - zip_content = api_client.get_platform_templates(tag) - new_local_version = _create_local_version_file_structure(release_info.tag, release_info) - - with zipfile.ZipFile(io.BytesIO(zip_content)) as zip_file: - zip_file.extractall(new_local_version.local_path) - print_success(f"Downloaded version {tag} to {new_local_version.local_path}") + new_version = LocalVersion( + **release_info.model_dump(), + local_path=PLATFORM_STORAGE_DIR / tag, + current=False, + ) - create_default_env_files(new_local_version) - return new_local_version + if new_version.local_path.exists() and not confirm( + f"{new_version.local_path} exists, overwrite?" + ): + return new_version + version_config.add_version(new_version) -def download_platform(tag: str | None = None) -> LocalVersionSchema: - """Download platform version if not already available locally. + # 4 - Pull the release zip and extract it - Args: - tag: Version tag to download (supports 'latest'). + zip_content = api_client.get_platform_templates(tag) + new_version.local_path.mkdir(parents=True, exist_ok=True) - Returns: - LocalVersionSchema: Local version schema for the downloaded/existing version. - """ - if not tag or tag == "latest": - # all remote images are tagged with architecture - tag = create_local_latest_tag() + with zipfile.ZipFile(io.BytesIO(zip_content)) as zip_file: + zip_file.extractall(new_version.local_path) + print_success(f"Downloaded [cyan]{tag}[/] to {new_version.local_path}") - if "latest" in tag: - tag = _resolve_latest(tag) + create_default_env_files(new_version) - # get what's available - available_local_versions = get_available_local_versions() + # 5 - Pull the images - # if there are versions available - if available_local_versions.versions: - for available_local_version in available_local_versions.versions: - if tag == available_local_version.tag: - print_success( - f"Version {tag} is already downloaded at {available_local_version.local_path}" - ) - return available_local_version + print_info(f"Pulling Docker images for [cyan]{tag}[/] ...") + compose_login(new_version, force=True) + docker_run( + ["compose", *get_compose_args(new_version), "pull"], + ) - print_info(f"Version {tag} is not available locally. Attempting to download it now ...") - return _download_version_files(tag) + return new_version diff --git a/dreadnode/cli/platform/utils/env_mgmt.py b/dreadnode/cli/platform/env_mgmt.py similarity index 96% rename from dreadnode/cli/platform/utils/env_mgmt.py rename to dreadnode/cli/platform/env_mgmt.py index c0303a06..fc935445 100644 --- a/dreadnode/cli/platform/utils/env_mgmt.py +++ b/dreadnode/cli/platform/env_mgmt.py @@ -1,13 +1,13 @@ -import subprocess +import subprocess # nosec import sys import typing as t from pathlib import Path from dreadnode.cli.platform.constants import ( - SERVICES, + PLATFORM_SERVICES, ) -from dreadnode.cli.platform.schemas import LocalVersionSchema -from dreadnode.cli.platform.utils.printing import print_error, print_info +from dreadnode.cli.platform.version import LocalVersion +from dreadnode.logging_ import print_error, print_info LineTypes = t.Literal["variable", "comment", "empty"] @@ -54,7 +54,8 @@ def _parse_env_lines(content: str) -> list[_EnvLine]: def _extract_variables(lines: list[_EnvLine]) -> dict[str, str]: - """Extract just the variables from parsed lines. + """ + Extract just the variables from parsed lines. Args: lines: List of parsed environment file lines. @@ -72,7 +73,8 @@ def _extract_variables(lines: list[_EnvLine]) -> dict[str, str]: def _find_insertion_points( base_lines: list[_EnvLine], remote_lines: list[_EnvLine], new_vars: dict[str, str] ) -> dict[str, int]: - """Find the best insertion points for new variables based on remote file structure. + """ + Find the best insertion points for new variables based on remote file structure. Args: base_lines: Lines from local file. @@ -148,7 +150,8 @@ def _find_insertion_points( def _reconstruct_env_content( # noqa: PLR0912 base_lines: list[_EnvLine], merged_vars: dict[str, str], updated_remote_lines: list[_EnvLine] ) -> str: - """Reconstruct .env content preserving structure from base while applying merged variables. + """ + Reconstruct .env content preserving structure from base while applying merged variables. Args: base_lines: Parsed lines from the local file (for structure). @@ -246,7 +249,7 @@ def _reconstruct_env_content( # noqa: PLR0912 return "\n".join(result_lines) -def create_default_env_files(current_version: LocalVersionSchema) -> None: +def create_default_env_files(current_version: LocalVersion) -> None: """Create default environment files for all services in the current version. Copies sample environment files to actual environment files if they don't exist, @@ -258,7 +261,7 @@ def create_default_env_files(current_version: LocalVersionSchema) -> None: Raises: RuntimeError: If sample environment files are not found or .env file creation fails. """ - for service in SERVICES: + for service in PLATFORM_SERVICES: for image in current_version.images: if image.service == service: env_file_path = current_version.get_env_path_by_service(service) @@ -287,7 +290,7 @@ def open_env_file(filename: Path) -> None: else: cmd = ["xdg-open", filename.as_posix()] try: - subprocess.run(cmd, check=False) # noqa: S603 + subprocess.run(cmd, check=False) # noqa: S603 # nosec print_info("Opened environment file.") except subprocess.CalledProcessError as e: print_error(f"Failed to open environment file: {e}") diff --git a/dreadnode/cli/platform/login.py b/dreadnode/cli/platform/login.py deleted file mode 100644 index ab020199..00000000 --- a/dreadnode/cli/platform/login.py +++ /dev/null @@ -1,20 +0,0 @@ -from dreadnode.cli.platform.docker_ import docker_login -from dreadnode.cli.platform.utils.printing import print_info -from dreadnode.cli.platform.utils.versions import get_current_version - - -def log_into_registries() -> None: - """Log into all Docker registries for the current platform version. - - Iterates through all images in the current version and logs into their - respective registries. If no current version is set, displays an error message. - """ - current_version = get_current_version() - if not current_version: - print_info("There are no registries configured. Run `dreadnode platform start` to start.") - return - registries_attempted = set() - for image in current_version.images: - if image.registry not in registries_attempted: - docker_login(image.registry) - registries_attempted.add(image.registry) diff --git a/dreadnode/cli/platform/schemas.py b/dreadnode/cli/platform/schemas.py deleted file mode 100644 index 1bca04d1..00000000 --- a/dreadnode/cli/platform/schemas.py +++ /dev/null @@ -1,117 +0,0 @@ -from pathlib import Path - -from pydantic import BaseModel, field_serializer - -from dreadnode.api.models import RegistryImageDetails -from dreadnode.cli.platform.constants import API_SERVICE, UI_SERVICE - - -class LocalVersionSchema(RegistryImageDetails): - local_path: Path - current: bool - - def __str__(self) -> str: - return self.tag - - @field_serializer("local_path") - def serialize_path(self, path: Path) -> str: - """Serialize Path object to absolute path string. - - Args: - path: Path object to serialize. - - Returns: - str: Absolute path as string. - """ - return str(path.resolve()) # Convert to absolute path string - - @property - def details(self) -> str: - configured_overrides = ( - "\n".join( - f" - {line}" for line in self.configure_overrides_env_file.read_text().splitlines() - ) - if self.configure_overrides_env_file.exists() - else " (none)" - ) - - return ( - f"Tag: {self.tag}\n" - f"Local Path: {self.local_path}\n" - f"Compose File: {self.compose_file}\n" - f"API Env File: {self.api_env_file}\n" - f"UI Env File: {self.ui_env_file}\n" - f"Configured: \n{configured_overrides}\n" - ) - - @property - def compose_file(self) -> Path: - return self.local_path / "docker-compose.yaml" - - @property - def api_env_file(self) -> Path: - return self.local_path / f".{API_SERVICE}.env" - - @property - def api_example_env_file(self) -> Path: - return self.local_path / f".{API_SERVICE}.example.env" - - @property - def ui_env_file(self) -> Path: - return self.local_path / f".{UI_SERVICE}.env" - - @property - def ui_example_env_file(self) -> Path: - return self.local_path / f".{UI_SERVICE}.example.env" - - @property - def configure_overrides_env_file(self) -> Path: - return self.local_path / ".configure.overrides.env" - - @property - def configure_overrides_compose_file(self) -> Path: - return self.local_path / "docker-compose.configure.overrides.yaml" - - @property - def arg_overrides_env_file(self) -> Path: - return self.local_path / ".arg.overrides.env" - - def get_env_path_by_service(self, service: str) -> Path: - """Get environment file path for a specific service. - - Args: - service: Service name to get env path for. - - Returns: - Path: Path to the service's environment file. - - Raises: - ValueError: If service is not recognized. - """ - if service == API_SERVICE: - return self.api_env_file - if service == UI_SERVICE: - return self.ui_env_file - raise ValueError(f"Unknown service: {service}") - - def get_example_env_path_by_service(self, service: str) -> Path: - """Get example environment file path for a specific service. - - Args: - service: Service name to get example env path for. - - Returns: - Path: Path to the service's example environment file. - - Raises: - ValueError: If service is not recognized. - """ - if service == API_SERVICE: - return self.api_example_env_file - if service == UI_SERVICE: - return self.ui_example_env_file - raise ValueError(f"Unknown service: {service}") - - -class LocalVersionsSchema(BaseModel): - versions: list[LocalVersionSchema] diff --git a/dreadnode/cli/platform/start.py b/dreadnode/cli/platform/start.py deleted file mode 100644 index f643c2d8..00000000 --- a/dreadnode/cli/platform/start.py +++ /dev/null @@ -1,79 +0,0 @@ -from dreadnode.cli.platform.docker_ import ( - DockerError, - docker_login, - docker_requirements_met, - docker_run, - docker_stop, - get_available_local_images, - get_env_var_from_container, - get_required_images, -) -from dreadnode.cli.platform.download import download_platform -from dreadnode.cli.platform.status import platform_is_running -from dreadnode.cli.platform.utils.env_mgmt import write_overrides_env -from dreadnode.cli.platform.utils.printing import print_error, print_info, print_success -from dreadnode.cli.platform.utils.versions import ( - create_local_latest_tag, - get_current_version, - mark_current_version, -) - - -def start_platform(tag: str | None = None, **env_overrides: str) -> None: - """Start the platform with the specified or current version. - - Args: - tag: Optional image tag to use. If not provided, uses the current - version or downloads the latest available version. - """ - if not docker_requirements_met(): - print_error("Docker and Docker Compose must be installed to start the platform.") - return - - if tag: - selected_version = download_platform(tag) - mark_current_version(selected_version) - elif current_version := get_current_version(): - selected_version = current_version - # no need to mark - else: - latest_tag = create_local_latest_tag() - selected_version = download_platform(latest_tag) - mark_current_version(selected_version) - - is_running = platform_is_running(selected_version) - if is_running: - print_info(f"Platform {selected_version.tag} is already running.") - print_info("Use `dreadnode platform stop` to stop it first.") - return - - # check to see if all required images are available locally - required_images = get_required_images(selected_version) - available_images = get_available_local_images() - missing_images = [img for img in required_images if img not in available_images] - if missing_images: - registries_attempted = set() - for image in selected_version.images: - if image.registry not in registries_attempted: - docker_login(image.registry) - registries_attempted.add(image.registry) - - if env_overrides: - write_overrides_env(selected_version.arg_overrides_env_file, **env_overrides) - - print_info(f"Starting platform: {selected_version.tag}") - try: - docker_run(selected_version) - print_success(f"Platform {selected_version.tag} started successfully.") - origin = get_env_var_from_container("dreadnode-ui", "ORIGIN") - if origin: - print_info("You can access the app at the following URLs:") - print_info(f" - {origin}") - else: - print_info(" - Unable to determine the app URL.") - print_info("Please check the container logs for more information.") - except DockerError as e: - print_error(f"Failed to start platform {selected_version.tag}: {e}") - print_info("Stopping any partially started containers...") - docker_stop(selected_version) - print_info("You can check the logs for more details.") diff --git a/dreadnode/cli/platform/status.py b/dreadnode/cli/platform/status.py deleted file mode 100644 index 02b29dfd..00000000 --- a/dreadnode/cli/platform/status.py +++ /dev/null @@ -1,43 +0,0 @@ -from dreadnode.cli.platform.docker_ import docker_ps, get_required_service_names -from dreadnode.cli.platform.schemas import LocalVersionSchema -from dreadnode.cli.platform.utils.printing import print_error, print_success -from dreadnode.cli.platform.utils.versions import get_current_version, get_local_version - - -def platform_is_running(selected_version: LocalVersionSchema) -> bool: - """Check if the platform with the specified or current version is running. - - Args: - tag: Optional image tag to use. If not provided, uses the current - version or downloads the latest available version. - """ - required_services = get_required_service_names(selected_version) - container_details = docker_ps(selected_version) - if not container_details: - return False - for service in required_services: - if service not in [c.name for c in container_details if c.status == "running"]: - return False - return True - - -def platform_status(tag: str | None = None) -> bool: - """Get the status of the platform with the specified or current version. - - Args: - tag: Optional image tag to use. If not provided, uses the current - version or downloads the latest available version. - """ - if tag: - selected_version = get_local_version(tag) - elif current_version := get_current_version(): - selected_version = current_version - else: - print_error("No current platform version is set. Please start or download the platform.") - return False - required_containers_running = platform_is_running(selected_version) - if required_containers_running: - print_success(f"Platform {selected_version.tag} is running.") - else: - print_error(f"Platform {selected_version.tag} is not fully running.") - return required_containers_running diff --git a/dreadnode/cli/platform/stop.py b/dreadnode/cli/platform/stop.py deleted file mode 100644 index 1079cfb2..00000000 --- a/dreadnode/cli/platform/stop.py +++ /dev/null @@ -1,21 +0,0 @@ -from dreadnode.cli.platform.docker_ import docker_stop -from dreadnode.cli.platform.utils.env_mgmt import remove_overrides_env -from dreadnode.cli.platform.utils.printing import print_error, print_success -from dreadnode.cli.platform.utils.versions import ( - get_current_version, -) - - -def stop_platform() -> None: - """Stop the currently running platform. - - Uses the current version's compose file to stop all platform containers - via docker compose down. - """ - current_version = get_current_version() - if not current_version: - print_error("No current version found. Nothing to stop.") - return - remove_overrides_env(current_version.arg_overrides_env_file) - docker_stop(current_version) - print_success("Platform stopped successfully.") diff --git a/dreadnode/cli/platform/tag.py b/dreadnode/cli/platform/tag.py new file mode 100644 index 00000000..5009cb93 --- /dev/null +++ b/dreadnode/cli/platform/tag.py @@ -0,0 +1,40 @@ +import platform + +from packaging.version import Version + +from dreadnode.cli.platform.constants import SUPPORTED_ARCHITECTURES + + +def add_tag_arch_suffix(tag: str) -> str: + """ + Add architecture suffix to a tag if it doesn't already have one. + + Args: + tag: The original tag string. + """ + if any(tag.endswith(f"-{arch}") for arch in SUPPORTED_ARCHITECTURES): + return tag # Tag already has a supported architecture suffix + + arch = platform.machine() + + if arch in ["x86_64", "AMD64"]: + arch = "amd64" + elif arch in ["arm64", "aarch64", "ARM64"]: + arch = "arm64" + else: + raise ValueError(f"Unsupported architecture: {arch}") + + return f"{tag}-{arch}" + + +def tag_to_semver(tag: str) -> Version: + """ + Extract semantic version from a tag by removing architecture suffix. + + Args: + tag: The tag string that may contain an architecture suffix. + + Returns: + str: The tag with any supported architecture suffix removed. + """ + return Version(tag.split("-")[0].removeprefix("v")) diff --git a/dreadnode/cli/platform/upgrade.py b/dreadnode/cli/platform/upgrade.py deleted file mode 100644 index e267bdee..00000000 --- a/dreadnode/cli/platform/upgrade.py +++ /dev/null @@ -1,64 +0,0 @@ -from dreadnode.cli.platform.docker_ import docker_stop -from dreadnode.cli.platform.download import download_platform -from dreadnode.cli.platform.start import start_platform -from dreadnode.cli.platform.utils.printing import print_error, print_info -from dreadnode.cli.platform.utils.versions import ( - confirm_with_context, - create_local_latest_tag, - get_current_version, - get_semver_from_tag, - mark_current_version, - newer_remote_version, -) - - -def upgrade_platform() -> None: - """Upgrade the platform to the latest available version. - - Downloads the latest version, compares it with the current version, - and performs the upgrade if a newer version is available. Optionally - merges configuration files from the current version to the new version. - Stops the current platform and starts the upgraded version. - """ - latest_tag = create_local_latest_tag() - - latest_version = download_platform(latest_tag) - current_local_version = get_current_version() - - if not current_local_version: - print_error( - "No current platform version found. Run `dreadnode platform start` to start the latest version." - ) - return - - current_semver = get_semver_from_tag(current_local_version.tag) - remote_semver = get_semver_from_tag(latest_version.tag) - - if not newer_remote_version(current_semver, remote_semver): - print_info(f"You are using the latest ({current_semver}) version of the platform.") - return - - if not confirm_with_context( - f"Are you sure you want to upgrade from {current_local_version.tag} to {latest_version.tag}?" - ): - print_error("Aborting upgrade.") - return - - # copy the configuration overrides from the current version to the new version - if ( - current_local_version.configure_overrides_compose_file.exists() - and current_local_version.configure_overrides_env_file.exists() - ): - latest_version.configure_overrides_compose_file.write_text( - current_local_version.configure_overrides_compose_file.read_text() - ) - latest_version.configure_overrides_env_file.write_text( - current_local_version.configure_overrides_env_file.read_text() - ) - - print_info(f"Stopping current platform version {current_local_version.tag}...") - docker_stop(current_local_version) - print_info(f"Current platform version {current_local_version.tag} stopped.") - - mark_current_version(latest_version) - start_platform() diff --git a/dreadnode/cli/platform/utils/__init__.py b/dreadnode/cli/platform/utils/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/dreadnode/cli/platform/utils/printing.py b/dreadnode/cli/platform/utils/printing.py deleted file mode 100644 index 2691e44e..00000000 --- a/dreadnode/cli/platform/utils/printing.py +++ /dev/null @@ -1,43 +0,0 @@ -import sys - -import rich - - -def print_success(message: str, prefix: str | None = None) -> None: - """Print success message in green""" - prefix = prefix or "✓" - rich.print(f"[bold green]{prefix}[/] [green]{message}[/]") - - -def print_error(message: str, prefix: str | None = None) -> None: - """Print error message in red""" - prefix = prefix or "✗" - rich.print(f"[bold red]{prefix}[/] [red]{message}[/]", file=sys.stderr) - - -def print_warning(message: str, prefix: str | None = None) -> None: - """Print warning message in yellow""" - prefix = prefix or "⚠" - rich.print(f"[bold yellow]{prefix}[/] [yellow]{message}[/]") - - -def print_info(message: str, prefix: str | None = None) -> None: - """Print info message in blue""" - prefix = prefix or "i" - rich.print(f"[bold blue]{prefix}[/] [blue]{message}[/]") - - -def print_debug(message: str, prefix: str | None = None) -> None: - """Print debug message in dim gray""" - prefix = prefix or "🐛" - rich.print(f"[dim]{prefix}[/] [dim]{message}[/]") - - -def print_heading(message: str) -> None: - """Print section heading""" - rich.print(f"\n[bold underline]{message}[/]\n") - - -def print_muted(message: str) -> None: - """Print muted text""" - rich.print(f"[dim]{message}[/]") diff --git a/dreadnode/cli/platform/utils/versions.py b/dreadnode/cli/platform/utils/versions.py deleted file mode 100644 index 17226cd9..00000000 --- a/dreadnode/cli/platform/utils/versions.py +++ /dev/null @@ -1,184 +0,0 @@ -import importlib.metadata -import json -import platform -from pathlib import Path - -from packaging.version import Version -from rich.prompt import Confirm - -from dreadnode.cli.platform.constants import ( - SUPPORTED_ARCHITECTURES, - VERSIONS_MANIFEST, - SupportedArchitecture, -) -from dreadnode.cli.platform.schemas import LocalVersionSchema, LocalVersionsSchema -from dreadnode.constants import DEFAULT_LOCAL_STORAGE_DIR - - -def _get_local_arch() -> SupportedArchitecture: - """Get the local machine architecture in supported format. - - Returns: - SupportedArchitecture: The architecture as either "amd64" or "arm64". - - Raises: - ValueError: If the local architecture is not supported. - """ - arch = platform.machine() - - if arch in ["x86_64", "AMD64"]: - return "amd64" - if arch in ["arm64", "aarch64", "ARM64"]: - return "arm64" - raise ValueError(f"Unsupported architecture: {arch}") - - -def get_local_cache_dir() -> Path: - """Get the local cache directory path for dreadnode platform files. - - Returns: - Path: Path to the local cache directory (~//platform). - """ - return DEFAULT_LOCAL_STORAGE_DIR / "platform" - - -def get_cli_version() -> str: - """Get the version of the dreadnode CLI package. - - Returns: - str | None: The version string of the dreadnode package, or None if not found. - """ - return importlib.metadata.version("dreadnode") - - -def confirm_with_context(action: str) -> bool: - """Prompt the user for confirmation with a formatted action message. - - Args: - action: The action description to display in the confirmation prompt. - - Returns: - bool: True if the user confirms, False otherwise. Defaults to False. - """ - return Confirm.ask(f"[bold blue]{action}[/bold blue]", default=False) - - -def get_available_local_versions() -> LocalVersionsSchema: - """Get all available local platform versions from the manifest file. - - Creates the manifest file with an empty schema if it doesn't exist. - - Returns: - LocalVersionsSchema: Schema containing all available local platform versions. - """ - try: - local_cache_dir = get_local_cache_dir() - manifest_path = local_cache_dir / VERSIONS_MANIFEST - with manifest_path.open(encoding="utf-8") as f: - versions_manifest_data = json.load(f) - return LocalVersionsSchema(**versions_manifest_data) - except FileNotFoundError: - # create the file - local_cache_dir = get_local_cache_dir() - manifest_path = local_cache_dir / VERSIONS_MANIFEST - manifest_path.parent.mkdir(parents=True, exist_ok=True) - blank_schema = LocalVersionsSchema(versions=[]) - with manifest_path.open(encoding="utf-8", mode="w") as f: - json.dump(blank_schema.model_dump(), f) - return blank_schema - - -def get_current_version() -> LocalVersionSchema | None: - """Get the currently active local platform version. - - Returns: - LocalVersionSchema | None: The current version schema if one is marked as current, - None otherwise. - """ - available_local_versions = get_available_local_versions() - if not available_local_versions.versions: - return None - for version in available_local_versions.versions: - if version.current: - return version - return None - - -def get_local_version(tag: str) -> LocalVersionSchema: - """Get a specific local platform version by its tag. - - Args: - tag: The tag of the version to retrieve. - - Returns: - LocalVersionSchema: The version schema matching the provided tag. - - Raises: - ValueError: If no version with the specified tag is found. - """ - available_local_versions = get_available_local_versions() - for version in available_local_versions.versions: - if version.tag == tag: - return version - raise ValueError(f"No local version found with tag: {tag}") - - -def mark_current_version(current_version: LocalVersionSchema) -> None: - """Mark a specific version as the current active version. - - Updates the versions manifest to mark the specified version as current - and all others as not current. - - Args: - current_version: The version to mark as current. - """ - available_local_versions = get_available_local_versions() - for available_version in available_local_versions.versions: - if available_version.tag == current_version.tag: - available_version.current = True - else: - available_version.current = False - - local_cache_dir = get_local_cache_dir() - manifest_path = local_cache_dir / VERSIONS_MANIFEST - with manifest_path.open(encoding="utf-8", mode="w") as f: - json.dump(available_local_versions.model_dump(), f, indent=2) - - -def create_local_latest_tag() -> str: - """Create a latest tag string for the local architecture. - - Returns: - str: A tag in the format "latest-{arch}" where arch is the local architecture. - """ - arch = _get_local_arch() - return f"latest-{arch}" - - -def get_semver_from_tag(tag: str) -> str: - """Extract semantic version from a tag by removing architecture suffix. - - Args: - tag: The tag string that may contain an architecture suffix. - - Returns: - str: The tag with any supported architecture suffix removed. - """ - for arch in SUPPORTED_ARCHITECTURES: - if arch in tag: - return tag.replace(f"-{arch}", "") - return tag - - -def newer_remote_version(local_version: str, remote_version: str) -> bool: - """Check if the remote version is newer than the local version. - - Args: - local_version: The local version string in semantic version format. - remote_version: The remote version string in semantic version format. - - Returns: - bool: True if the remote version is newer than the local version, False otherwise. - """ - # compare the semvers of two versions to see if the remote is "newer" - return Version(remote_version) > Version(local_version) diff --git a/dreadnode/cli/platform/version.py b/dreadnode/cli/platform/version.py new file mode 100644 index 00000000..605643bc --- /dev/null +++ b/dreadnode/cli/platform/version.py @@ -0,0 +1,216 @@ +from pathlib import Path + +from pydantic import BaseModel, field_serializer + +from dreadnode.api.models import RegistryImageDetails +from dreadnode.cli.platform.constants import ( + API_SERVICE, + UI_SERVICE, + VERSION_CONFIG_PATH, +) +from dreadnode.cli.platform.tag import tag_to_semver + + +class LocalVersion(RegistryImageDetails): + local_path: Path + current: bool + + def __str__(self) -> str: + return self.tag + + @field_serializer("local_path") + def serialize_path(self, path: Path) -> str: + """Serialize Path object to absolute path string. + + Args: + path: Path object to serialize. + + Returns: + str: Absolute path as string. + """ + return str(path.resolve()) # Convert to absolute path string + + @property + def details(self) -> str: + configured_overrides = ( + "\n".join( + f" - {line}" for line in self.configure_overrides_env_file.read_text().splitlines() + ) + if self.configure_overrides_env_file.exists() + else " (none)" + ) + + return ( + f"Tag: {self.tag}\n" + f"Local Path: {self.local_path}\n" + f"Compose File: {self.compose_file}\n" + f"API Env File: {self.api_env_file}\n" + f"UI Env File: {self.ui_env_file}\n" + f"Configured: \n{configured_overrides}\n" + ) + + @property + def compose_file(self) -> Path: + return self.local_path / "docker-compose.yaml" + + @property + def api_env_file(self) -> Path: + return self.local_path / f".{API_SERVICE}.env" + + @property + def api_example_env_file(self) -> Path: + return self.local_path / f".{API_SERVICE}.example.env" + + @property + def ui_env_file(self) -> Path: + return self.local_path / f".{UI_SERVICE}.env" + + @property + def ui_example_env_file(self) -> Path: + return self.local_path / f".{UI_SERVICE}.example.env" + + @property + def configure_overrides_env_file(self) -> Path: + return self.local_path / ".configure.overrides.env" + + @property + def configure_overrides_compose_file(self) -> Path: + return self.local_path / "docker-compose.configure.overrides.yaml" + + @property + def arg_overrides_env_file(self) -> Path: + return self.local_path / ".arg.overrides.env" + + def get_env_path_by_service(self, service: str) -> Path: + """Get environment file path for a specific service. + + Args: + service: Service name to get env path for. + + Returns: + Path: Path to the service's environment file. + + Raises: + ValueError: If service is not recognized. + """ + if service == API_SERVICE: + return self.api_env_file + if service == UI_SERVICE: + return self.ui_env_file + raise ValueError(f"Unknown service: {service}") + + def get_example_env_path_by_service(self, service: str) -> Path: + """Get example environment file path for a specific service. + + Args: + service: Service name to get example env path for. + + Returns: + Path: Path to the service's example environment file. + + Raises: + ValueError: If service is not recognized. + """ + if service == API_SERVICE: + return self.api_example_env_file + if service == UI_SERVICE: + return self.ui_example_env_file + raise ValueError(f"Unknown service: {service}") + + +class VersionConfig(BaseModel): + versions: list[LocalVersion] + + @classmethod + def read(cls) -> "VersionConfig": + """Read the version configuration from the file system or return an empty instance.""" + + if not VERSION_CONFIG_PATH.exists(): + return cls(versions=[]) + + with VERSION_CONFIG_PATH.open("r") as f: + return cls.model_validate_json(f.read()) + + def write(self) -> None: + """Write the versions configuration to the file system.""" + + if not VERSION_CONFIG_PATH.parent.exists(): + VERSION_CONFIG_PATH.parent.mkdir(parents=True) + + with VERSION_CONFIG_PATH.open("w") as f: + f.write(self.model_dump_json()) + + def add_version(self, version: LocalVersion) -> None: + """ + Add a new version to the configuration if it doesn't already exist. + + Args: + version: The LocalVersion instance to add. + """ + if next((v for v in self.versions if v.tag == version.tag), None) is None: + self.versions.append(version) + self.write() + + def get_current_version(self, *, tag: str | None = None) -> LocalVersion | None: + """Get the current active version or a specific version by tag.""" + if tag: + return next((v for v in self.versions if v.tag == tag), None) + + if current := next((v for v in self.versions if v.current), None): + return current + + if latest := self.get_latest_version(): + self.set_current_version(latest) + return latest + + return None + + def get_latest_version(self) -> LocalVersion | None: + """Get the latest version based on semantic versioning.""" + if not self.versions: + return None + sorted_versions = sorted( + self.versions, + key=lambda v: tag_to_semver(v.tag), + reverse=True, + ) + return sorted_versions[0] + + def get_by_tag(self, tag: str) -> LocalVersion: + """ + Get a specific local platform version by its tag. + + Args: + tag: The tag of the version to retrieve. + + Returns: + LocalVersion: The version schema matching the provided tag. + + Raises: + ValueError: If no version with the specified tag is found. + """ + for version in self.versions: + if version.tag == tag: + return version + raise ValueError(f"No local version found with tag: {tag}") + + def set_current_version(self, version: LocalVersion) -> None: + """ + Mark a specific version as the current active version. + + Updates the versions manifest to mark the specified version as current + and all others as not current. + + Args: + version: The version to mark as current. + """ + if next((v for v in self.versions if v.tag == version.tag), None) is None: + self.versions.append(version) + + for available_version in self.versions: + if available_version.tag == version.tag: + available_version.current = True + else: + available_version.current = False + + self.write() diff --git a/dreadnode/cli/profile/cli.py b/dreadnode/cli/profile/cli.py index 9ed1ccf8..ce2ac3a2 100644 --- a/dreadnode/cli/profile/cli.py +++ b/dreadnode/cli/profile/cli.py @@ -1,12 +1,12 @@ import typing as t import cyclopts -import rich from rich import box from rich.prompt import Prompt from rich.table import Table from dreadnode.cli.api import Token +from dreadnode.logging_ import console, print_error, print_info, print_success from dreadnode.user_config import UserConfig from dreadnode.util import time_to @@ -19,7 +19,7 @@ def show() -> None: config = UserConfig.read() if not config.servers: - rich.print(":exclamation: No server profiles are configured") + print_error("No server profiles are configured") return table = Table(box=box.ROUNDED) @@ -44,7 +44,7 @@ def show() -> None: style="bold" if active else None, ) - rich.print(table) + console.print(table) @cli.command() @@ -55,37 +55,39 @@ def switch( config = UserConfig.read() if not config.servers: - rich.print(":exclamation: No server profiles are configured") + print_error("No server profiles are configured") return # If no profile provided, prompt user to choose if profile is None: profiles = list(config.servers.keys()) - rich.print("\nAvailable profiles:") + print_info("Available profiles:") for i, p in enumerate(profiles, 1): active_marker = " (current)" if p == config.active else "" - rich.print(f" {i}. [bold orange_red1]{p}[/]{active_marker}") + print_info(f" {i}. [bold orange_red1]{p}[/]{active_marker}") choice = Prompt.ask( "\nSelect a profile", choices=[str(i) for i in range(1, len(profiles) + 1)] + profiles, show_choices=False, + console=console, ) profile = profiles[int(choice) - 1] if choice.isdigit() else choice if profile not in config.servers: - rich.print(f":exclamation: Profile [bold]{profile}[/] does not exist") + print_error(f"Profile [bold]{profile}[/] does not exist") return config.active = profile config.write() - rich.print(f":laptop_computer: Switched to [bold orange_red1]{profile}[/]") - rich.print(f"|- email: [bold]{config.servers[profile].email}[/]") - rich.print(f"|- username: {config.servers[profile].username}") - rich.print(f"|- url: {config.servers[profile].url}") - rich.print() + print_success( + f"Switched to [bold orange_red1]{profile}[/]\n" + f"|- email: [bold]{config.servers[profile].email}[/]\n" + f"|- username: {config.servers[profile].username}\n" + f"|- url: {config.servers[profile].url}\n" + ) @cli.command() @@ -95,10 +97,10 @@ def forget( """Remove a server profile from the configuration.""" config = UserConfig.read() if profile not in config.servers: - rich.print(f":exclamation: Profile [bold]{profile}[/] does not exist") + print_error(f"Profile [bold]{profile}[/] does not exist") return del config.servers[profile] config.write() - rich.print(f":axe: Forgot about [bold]{profile}[/]") + print_success(f"Forgot about [bold]{profile}[/]") diff --git a/dreadnode/cli/shared.py b/dreadnode/cli/shared.py index 080223b4..2f988698 100644 --- a/dreadnode/cli/shared.py +++ b/dreadnode/cli/shared.py @@ -3,7 +3,7 @@ import cyclopts -from dreadnode.logging_ import LogLevelLiteral, configure_logging +from dreadnode.logging_ import LogLevel, configure_logging @cyclopts.Parameter(name="dn", group="Dreadnode") @@ -19,7 +19,7 @@ class DreadnodeConfig: """Profile name""" console: t.Annotated[bool, cyclopts.Parameter(negative=False)] = False """Show spans in the console""" - log_level: LogLevelLiteral | None = None + log_level: LogLevel | None = None """Console log level""" def apply(self) -> None: diff --git a/dreadnode/logging_.py b/dreadnode/logging_.py index 09b01a87..b0bb0602 100644 --- a/dreadnode/logging_.py +++ b/dreadnode/logging_.py @@ -5,35 +5,32 @@ """ import pathlib -import sys import typing as t +from textwrap import dedent from loguru import logger +from rich.console import Console +from rich.logging import RichHandler +from rich.prompt import Confirm +from rich.theme import Theme -if t.TYPE_CHECKING: - from loguru import Record as LogRecord - -g_configured: bool = False - -LogLevelList = ["trace", "debug", "info", "success", "warning", "error", "critical"] -LogLevelLiteral = t.Literal["trace", "debug", "info", "success", "warning", "error", "critical"] +LogLevel = t.Literal["trace", "debug", "info", "success", "warning", "error", "critical"] """Valid logging levels.""" - -def log_formatter(record: "LogRecord") -> str: - return "".join( - ( - "{time:HH:mm:ss.SSS} | ", - "{extra[prefix]} " if record["extra"].get("prefix") else "", - "{message}\n", - ) - ) +console = Console( + theme=Theme( # rich doesn't include default colors for these + { + "logging.level.success": "green", + "logging.level.trace": "dim blue", + } + ), +) def configure_logging( - log_level: LogLevelLiteral = "info", + log_level: LogLevel = "info", log_file: pathlib.Path | None = None, - log_file_level: LogLevelLiteral = "debug", + log_file_level: LogLevel = "debug", ) -> None: """ Configures common loguru handlers. @@ -44,26 +41,40 @@ def configure_logging( will only be done to the console. log_file_level: The log level for the log file. """ - global g_configured # noqa: PLW0603 - - if g_configured: - return - logger.enable("dreadnode") - logger.level("TRACE", color="", icon="[T]") - logger.level("DEBUG", color="", icon="[_]") - logger.level("INFO", color="", icon="[=]") - logger.level("SUCCESS", color="", icon="[+]") - logger.level("WARNING", color="", icon="[-]") - logger.level("ERROR", color="", icon="[!]") - logger.level("CRITICAL", color="", icon="[x]") - logger.remove() - logger.add(sys.stderr, format=log_formatter, level=log_level.upper()) + logger.add( + RichHandler(console=console, log_time_format="%X"), + format=lambda _: "{message}", + level=log_level.upper(), + ) if log_file is not None: - logger.add(log_file, format=log_formatter, level=log_file_level.upper()) + logger.add(log_file, level=log_file_level.upper()) logger.info(f"Logging to {log_file}") - g_configured = True + +def print_success(message: str) -> None: + console.print(f"[bold green]:heavy_check_mark:[/] {dedent(message)}") + + +def print_error(message: str) -> None: + console.print(f"[bold red]:heavy_multiplication_x:[/] {dedent(message)}") + + +def print_warning(message: str) -> None: + console.print(f"[bold yellow]:warning:[/] {dedent(message)}") + + +def print_info(message: str) -> None: + console.print(dedent(message)) + + +def confirm(action: str) -> bool: + return Confirm.ask( + f"[bold magenta]:left-right_arrow:[/] {action}", + default=False, + case_sensitive=False, + console=console, + ) diff --git a/dreadnode/main.py b/dreadnode/main.py index 54da44de..5293654e 100644 --- a/dreadnode/main.py +++ b/dreadnode/main.py @@ -10,7 +10,6 @@ import coolname # type: ignore [import-untyped] import logfire -import rich from fsspec.implementations.local import ( # type: ignore [import-untyped] LocalFileSystem, ) @@ -41,6 +40,7 @@ ENV_SERVER_URL, ) from dreadnode.error import AssertionFailedError +from dreadnode.logging_ import console as logging_console from dreadnode.metric import ( Metric, MetricAggMode, @@ -277,11 +277,13 @@ def configure( # Log config information for clarity if self.server or self.token or self.local_dir: destination = self.server or DEFAULT_SERVER_URL or "local storage" - rich.print(f"Dreadnode logging to [orange_red1]{destination}[/] ({config_source})") + logging_console.print( + f"Dreadnode logging to [orange_red1]{destination}[/] ({config_source})" + ) # Warn the user if the profile didn't resolve elif active_profile and not (self.server or self.token): - rich.print( + logging_console.print( f":exclamation: Dreadnode profile [orange_red1]{active_profile}[/] appears invalid." ) diff --git a/dreadnode/user_config.py b/dreadnode/user_config.py index f1daa806..2131910f 100644 --- a/dreadnode/user_config.py +++ b/dreadnode/user_config.py @@ -1,8 +1,8 @@ -import rich from pydantic import BaseModel from ruamel.yaml import YAML from dreadnode.constants import DEFAULT_PROFILE_NAME, USER_CONFIG_PATH +from dreadnode.logging_ import print_info class ServerConfig(BaseModel): @@ -62,7 +62,7 @@ def write(self) -> None: self._update_active() if not USER_CONFIG_PATH.parent.exists(): - rich.print(f":rocket: Creating config at {USER_CONFIG_PATH.parent}") + print_info(f"Creating config at {USER_CONFIG_PATH.parent}") USER_CONFIG_PATH.parent.mkdir(parents=True) with USER_CONFIG_PATH.open("w") as f: diff --git a/pyproject.toml b/pyproject.toml index 3701e5c2..717db1c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -182,7 +182,7 @@ skip-magic-trailing-comma = false ] "tests/**/*.py" = [ "INP001", # namespace not required for pytest - "S101", # asserts allowed in tests... "SLF001", # allow access to private members "PLR2004", # magic values + "S1", # security issues in tests are not relevant ] diff --git a/tests/cli/test_config.py b/tests/cli/test_config.py new file mode 100644 index 00000000..a82a9457 --- /dev/null +++ b/tests/cli/test_config.py @@ -0,0 +1,79 @@ +from pathlib import Path + +import pydantic +import pytest + +from dreadnode.user_config import ServerConfig, UserConfig + + +def test_server_config() -> None: + # Test valid server config + config = ServerConfig( + url="https://platform.dreadnode.io", + email="test@example.com", + username="test", + api_key="test123", # pragma: allowlist secret + access_token="token123", + refresh_token="refresh123", + ) + assert config.url == "https://platform.dreadnode.io" + assert config.email == "test@example.com" + assert config.username == "test" + assert config.api_key == "test123" # pragma: allowlist secret + assert config.access_token == "token123" + assert config.refresh_token == "refresh123" + + # Test invalid server config model + with pytest.raises(pydantic.ValidationError): + ServerConfig.model_validate({"invalid": "data"}) + + +def test_user_config(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + # Mock config path to use temporary directory + mock_config_path = tmp_path / "config.yaml" + monkeypatch.setattr("dreadnode.user_config.USER_CONFIG_PATH", mock_config_path) + + # Test empty config + config = UserConfig() + assert config.active is None + assert config.servers == {} + + # Test adding server config + server_config = ServerConfig( + url="https://platform.dreadnode.io", + email="test@example.com", + username="test", + api_key="test123", # pragma: allowlist secret + access_token="token123", + refresh_token="refresh123", + ) + + config.set_server_config(server_config, "default") + assert "default" in config.servers + assert config.get_server_config("default") == server_config + + # Test active profile + config.active = "default" + assert config.get_server_config() == server_config + + # Test writing and reading config + config.write() + assert mock_config_path.exists() + + loaded_config = UserConfig.read() + assert loaded_config.active == "default" + assert loaded_config.servers["default"] == server_config + + # Test invalid profile access + with pytest.raises(RuntimeError): + config.get_server_config("nonexistent") + + # Test auto-setting active profile + config.active = None + config._update_active() + assert config.active == "default" + + # Test empty config edge case + empty_config = UserConfig() + empty_config._update_active() + assert empty_config.active is None diff --git a/tests/cli/test_docker.py b/tests/cli/test_docker.py new file mode 100644 index 00000000..eea20641 --- /dev/null +++ b/tests/cli/test_docker.py @@ -0,0 +1,161 @@ +import pytest + +from dreadnode.cli.docker import DockerImage + + +@pytest.mark.parametrize( + ("input_str", "expected_repo", "expected_tag"), + [ + ("ubuntu", "library/ubuntu", "latest"), + ("postgres", "library/postgres", "latest"), + ("alpine:3.18", "library/alpine", "3.18"), + ("redis:7", "library/redis", "7"), + ], +) +def test_docker_image_official_images( + input_str: str, expected_repo: str, expected_tag: str +) -> None: + """Tests parsing of official Docker Hub images (e.g., 'ubuntu').""" + img = DockerImage(input_str) + assert img.registry is None # Should be None if not provided + assert img.repository == expected_repo + assert img.tag == expected_tag + assert img.digest is None + + +@pytest.mark.parametrize( + ("input_str", "expected_repo", "expected_tag"), + [ + ("dreadnode/image", "dreadnode/image", "latest"), + ("bitnami/postgresql", "bitnami/postgresql", "latest"), + ("minio/minio:RELEASE.2023-03-20T20-16-18Z", "minio/minio", "RELEASE.2023-03-20T20-16-18Z"), + ], +) +def test_docker_image_namespaced_images( + input_str: str, expected_repo: str, expected_tag: str +) -> None: + """Tests parsing of namespaced Docker Hub images (e.g., 'dreadnode/image').""" + img = DockerImage(input_str) + assert img.registry is None # Should be None if not provided + assert img.repository == expected_repo + assert img.tag == expected_tag + assert img.digest is None + + +@pytest.mark.parametrize( + ("input_str", "expected_registry", "expected_repo", "expected_tag"), + [ + ("gcr.io/google-samples/hello-app:1.0", "gcr.io", "google-samples/hello-app", "1.0"), + ("ghcr.io/owner/image:tag", "ghcr.io", "owner/image", "tag"), + ("localhost:5000/my-app", "localhost:5000", "my-app", "latest"), + ("my.registry:1234/a/b/c:v1", "my.registry:1234", "a/b/c", "v1"), + ], +) +def test_docker_image_with_custom_registry( + input_str: str, expected_registry: str, expected_repo: str, expected_tag: str +) -> None: + """Tests parsing of images with a full registry hostname.""" + img = DockerImage(input_str) + assert img.registry == expected_registry + assert img.repository == expected_repo + assert img.tag == expected_tag + assert img.digest is None + + +@pytest.mark.parametrize( + ("input_str", "expected_repo", "expected_tag", "expected_digest"), + [ + ("ubuntu@sha256:abc", "library/ubuntu", None, "sha256:abc"), + ("dreadnode/image@sha256:123", "dreadnode/image", None, "sha256:123"), + ("gcr.io/app/image@sha256:xyz", "app/image", None, "sha256:xyz"), + ("ubuntu:22.04@sha256:456", "library/ubuntu", "22.04", "sha256:456"), + ], +) +def test_docker_image_with_digest( + input_str: str, expected_repo: str, expected_tag: str | None, expected_digest: str +) -> None: + """Tests parsing of images with a digest.""" + img = DockerImage(input_str) + assert img.repository == expected_repo + assert img.tag == expected_tag + assert img.digest == expected_digest + + +def test_docker_image_self_format_and_normalization() -> None: + """Test that DockerImage can handle its own normalized string outputs.""" + img1 = DockerImage("ubuntu") + assert str(img1) == "library/ubuntu:latest" + + img2 = DockerImage(str(img1)) + assert img2.registry is None + assert img2.repository == "library/ubuntu" + assert img2.tag == "latest" + assert img1 == img2 + + +def test_docker_image_whitespace_handling() -> None: + """Test that leading/trailing whitespace is properly stripped.""" + img = DockerImage(" ubuntu:22.04 \n") + assert img.repository == "library/ubuntu" + assert img.tag == "22.04" + assert str(img) == "library/ubuntu:22.04" + + +@pytest.mark.parametrize( + "case", + [ + "", # Empty string + " ", # Just whitespace + "@sha256:123", # Just a digest + ], +) +def test_docker_image_invalid_formats(case: str) -> None: + """Test that invalid formats raise ValueError.""" + with pytest.raises(ValueError, match="Invalid Docker image format"): + DockerImage(case) + + +def test_docker_image_string_methods_inheritance() -> None: + """Test that string methods from the str parent class work as expected.""" + img = DockerImage("ubuntu:22.04") + assert str(img) == "library/ubuntu:22.04" + assert img.upper() == "LIBRARY/UBUNTU:22.04" + assert img.startswith("library") + + +def test_docker_image_comparisons() -> None: + """Test comparison operations.""" + img1 = DockerImage("ubuntu") + img2 = DockerImage("library/ubuntu:latest") + img3 = DockerImage("postgres") + img4 = DockerImage("docker.io/library/ubuntu:latest") + + assert img1 == img2 + assert img1 != img3 + assert img1 != img4 # These are now different, as one specifies a registry + assert img1 == "library/ubuntu:latest" + + +def test_docker_image_with_method() -> None: + """Tests the with_() method for creating modified copies.""" + original = DockerImage("gcr.io/project/image:1.0") + + # Change registry + with_new_registry = original.with_(registry="ghcr.io") + assert str(with_new_registry) == "ghcr.io/project/image:1.0" + assert with_new_registry.registry == "ghcr.io" + + # Remove registry by setting to None + with_no_registry = original.with_(registry=None) + assert str(with_no_registry) == "project/image:1.0" + assert with_no_registry.registry is None + + # Change tag and remove digest + with_digest = DockerImage("gcr.io/project/image:1.0@sha256:abc") + with_new_tag = with_digest.with_(tag="2.0-beta", digest=None) + assert str(with_new_tag) == "gcr.io/project/image:2.0-beta" + assert with_new_tag.tag == "2.0-beta" + assert with_new_tag.digest is None + + # Ensure original object is not mutated + assert str(original) == "gcr.io/project/image:1.0" diff --git a/tests/cli/test_github.py b/tests/cli/test_github.py new file mode 100644 index 00000000..2bb20c74 --- /dev/null +++ b/tests/cli/test_github.py @@ -0,0 +1,203 @@ +import pytest + +from dreadnode.cli.github import GithubRepo + + +def test_github_repo_simple_format() -> None: + repo = GithubRepo("owner/repo") + assert repo.namespace == "owner" + assert repo.repo == "repo" + assert repo.ref == "main" + assert str(repo) == "owner/repo@main" + + +def test_github_repo_simple_format_with_ref() -> None: + repo = GithubRepo("owner/repo/tree/develop") + assert repo.namespace == "owner" + assert repo.repo == "repo" + assert repo.ref == "develop" + assert str(repo) == "owner/repo@develop" + + +@pytest.mark.parametrize( + "case", + [ + "https://github.com/owner/repo", + "http://github.com/owner/repo", + "https://github.com/owner/repo.git", + ], +) +def test_github_repo_https_url(case: str) -> None: + repo = GithubRepo(case) + assert repo.namespace == "owner" + assert repo.repo == "repo" + assert repo.ref == "main" + assert str(repo) == "owner/repo@main" + + +@pytest.mark.parametrize( + ("case", "expected_ref"), + [ + ("https://github.com/owner/repo/tree/feature/custom-branch", "feature/custom-branch"), + ("https://github.com/owner/repo/blob/feature/custom-branch", "feature/custom-branch"), + ], +) +def test_github_repo_https_url_with_ref(case: str, expected_ref: str) -> None: + repo = GithubRepo(case) + assert repo.namespace == "owner" + assert repo.repo == "repo" + assert repo.ref == expected_ref + assert str(repo) == f"owner/repo@{expected_ref}" + + +@pytest.mark.parametrize( + "case", + [ + "git@github.com:owner/repo", + "git@github.com:owner/repo.git", + ], +) +def test_github_repo_ssh_url(case: str) -> None: + repo = GithubRepo(case) + assert repo.namespace == "owner" + assert repo.repo == "repo" + assert repo.ref == "main" + assert str(repo) == "owner/repo@main" + + +@pytest.mark.parametrize( + ("case", "expected_ref"), + [ + ("https://raw.githubusercontent.com/owner/repo/main", "main"), + ("https://raw.githubusercontent.com/owner/repo/feature-branch", "feature-branch"), + ("https://raw.githubusercontent.com/owner/repo/feature/branch", "feature/branch"), + ], +) +def test_github_repo_raw_githubusercontent(case: str, expected_ref: str) -> None: + repo = GithubRepo(case) + assert repo.namespace == "owner" + assert repo.repo == "repo" + assert repo.ref == expected_ref + assert str(repo) == f"owner/repo@{expected_ref}" + + +@pytest.mark.parametrize( + ("input_str", "expected_ref"), + [ + ("owner/repo/tree/feature/custom", "feature/custom"), + ("owner/repo/releases/tag/v1.0.0", "v1.0.0"), + ], +) +def test_github_repo_ref_handling(input_str: str, expected_ref: str) -> None: + """Test handling of different reference formats""" + repo = GithubRepo(input_str) + assert repo.namespace == "owner" + assert repo.repo == "repo" + assert repo.ref == expected_ref + assert repo.zip_url == f"https://github.com/owner/repo/zipball/{expected_ref}" + + +@pytest.mark.parametrize( + "case", + [ + "owner/repo.js", + "https://github.com/owner/repo.js", + "git@github.com:owner/repo.js.git", + ], +) +def test_github_repo_with_dots(case: str) -> None: + """Test repositories with dots in names""" + repo = GithubRepo(case) + assert repo.namespace == "owner" + assert repo.repo == "repo.js" + assert str(repo) == "owner/repo.js@main" + + +@pytest.mark.parametrize( + "case", + [ + "owner-name/repo-name", + "https://github.com/owner-name/repo-name", + "git@github.com:owner-name/repo-name.git", + ], +) +def test_github_repo_with_dashes(case: str) -> None: + """Test repositories with dashes in names""" + repo = GithubRepo(case) + assert repo.namespace == "owner-name" + assert repo.repo == "repo-name" + assert str(repo) == "owner-name/repo-name@main" + + +@pytest.mark.parametrize( + "case", + [ + " owner/repo ", + "\nowner/repo\n", + "\towner/repo\t", + ], +) +def test_github_repo_whitespace_handling(case: str) -> None: + """Test that whitespace is properly stripped""" + repo = GithubRepo(case) + assert repo.namespace == "owner" + assert repo.repo == "repo" + assert str(repo) == "owner/repo@main" + + +@pytest.mark.parametrize( + "case", + [ + "", # Empty string + "owner", # Missing repo + "owner/", # Missing repo + "/repo", # Missing owner + "owner/repo/extra", # Too many parts + "http://gitlab.com/owner/repo", # Wrong domain + "git@gitlab.com:owner/repo.git", # Wrong domain + ], +) +def test_github_repo_invalid_formats(case: str) -> None: + """Test that invalid formats raise ValueError""" + with pytest.raises(ValueError, match="Invalid GitHub repository format"): + GithubRepo(case) + + +def test_github_repo_string_methods_inheritance() -> None: + """Test that string methods work as expected""" + repo = GithubRepo("owner/repo") + assert repo.upper() == "OWNER/REPO@MAIN" + assert repo.split("/") == ["owner", "repo@main"] + assert repo.split("@") == ["owner/repo", "main"] + assert repo.replace("owner", "newowner") == "newowner/repo@main" + assert len(repo) == len("owner/repo@main") + + +def test_github_repo_comparisons() -> None: + """Test comparison operations""" + repo1 = GithubRepo("owner/repo") + repo2 = GithubRepo("owner/repo") + repo3 = GithubRepo("different/repo") + + assert repo1 == repo2 + assert repo1 != repo3 + assert repo1 == "owner/repo@main" + assert repo1 == "owner/repo@main" + assert repo1 in ["owner/repo@main", "other/repo@main"] + + +def test_github_repo_self_format() -> None: + """Test that GithubRepo can handle its own string representations""" + # Test basic format + repo1 = GithubRepo("owner/repo@main") + assert repo1.namespace == "owner" + assert repo1.repo == "repo" + assert repo1.ref == "main" + assert str(repo1) == "owner/repo@main" + + # Test creating from existing repo string + repo2 = GithubRepo(str(repo1)) + assert repo2.namespace == "owner" + assert repo2.repo == "repo" + assert repo2.ref == "main" + assert str(repo2) == str(repo1) diff --git a/uv.lock b/uv.lock index 94a607e9..ef020a36 100644 --- a/uv.lock +++ b/uv.lock @@ -681,7 +681,7 @@ wheels = [ [[package]] name = "dreadnode" -version = "1.14.0" +version = "1.14.1" source = { editable = "." } dependencies = [ { name = "coolname" }, From ebb6ba5c8276337674eb4469c88659e3c6248cbb Mon Sep 17 00:00:00 2001 From: monoxgas Date: Thu, 2 Oct 2025 02:02:06 -0600 Subject: [PATCH 2/3] co-pilot suggestions --- dreadnode/cli/platform/compose.py | 4 +--- dreadnode/cli/platform/tag.py | 2 +- tests/cli/test_github.py | 2 -- 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/dreadnode/cli/platform/compose.py b/dreadnode/cli/platform/compose.py index 37813822..d9176c46 100644 --- a/dreadnode/cli/platform/compose.py +++ b/dreadnode/cli/platform/compose.py @@ -66,9 +66,7 @@ def get_required_services(version: LocalVersion) -> list[str]: contents: dict[str, object] = yaml.safe_load(version.compose_file.read_text()) services = t.cast("dict[str, object]", contents.get("services", {}) or {}) return [ - name - for name, cfg in services.items() - if isinstance(cfg, dict) and "x-required" in cfg and cfg["x-required"] is True + name for name, cfg in services.items() if isinstance(cfg, dict) and cfg.get("x-required") ] diff --git a/dreadnode/cli/platform/tag.py b/dreadnode/cli/platform/tag.py index 5009cb93..e93ed723 100644 --- a/dreadnode/cli/platform/tag.py +++ b/dreadnode/cli/platform/tag.py @@ -35,6 +35,6 @@ def tag_to_semver(tag: str) -> Version: tag: The tag string that may contain an architecture suffix. Returns: - str: The tag with any supported architecture suffix removed. + A packaging Version object representing the semantic version. """ return Version(tag.split("-")[0].removeprefix("v")) diff --git a/tests/cli/test_github.py b/tests/cli/test_github.py index 2bb20c74..6fc680ba 100644 --- a/tests/cli/test_github.py +++ b/tests/cli/test_github.py @@ -182,8 +182,6 @@ def test_github_repo_comparisons() -> None: assert repo1 == repo2 assert repo1 != repo3 assert repo1 == "owner/repo@main" - assert repo1 == "owner/repo@main" - assert repo1 in ["owner/repo@main", "other/repo@main"] def test_github_repo_self_format() -> None: From ae070dbf790d14c0644300da617d8fcc0fc672eb Mon Sep 17 00:00:00 2001 From: Brian Greunke Date: Fri, 3 Oct 2025 09:33:51 -0500 Subject: [PATCH 3/3] fix: fixed platform status checker --- dreadnode/cli/platform/compose.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dreadnode/cli/platform/compose.py b/dreadnode/cli/platform/compose.py index d9176c46..f0ebd992 100644 --- a/dreadnode/cli/platform/compose.py +++ b/dreadnode/cli/platform/compose.py @@ -66,7 +66,9 @@ def get_required_services(version: LocalVersion) -> list[str]: contents: dict[str, object] = yaml.safe_load(version.compose_file.read_text()) services = t.cast("dict[str, object]", contents.get("services", {}) or {}) return [ - name for name, cfg in services.items() if isinstance(cfg, dict) and cfg.get("x-required") + cfg.get("container_name") + for name, cfg in services.items() + if isinstance(cfg, dict) and cfg.get("x-required") ] @@ -166,7 +168,7 @@ def platform_is_running(version: LocalVersion) -> bool: return False for service in get_required_services(version): - if service not in [c.name for c in containers if c.status == "running"]: + if service not in [c.name for c in containers if c.state == "running"]: return False return True