diff --git a/docs/sdk/api.mdx b/docs/sdk/api.mdx
index 2512b872..c8ab6597 100644
--- a/docs/sdk/api.mdx
+++ b/docs/sdk/api.mdx
@@ -505,37 +505,6 @@ def export_timeseries(
```
-
-
-### get\_container\_registry\_credentials
-
-```python
-get_container_registry_credentials() -> (
- ContainerRegistryCredentials
-)
-```
-
-Retrieves container registry credentials for Docker image access.
-
-**Returns:**
-
-* `ContainerRegistryCredentials`
- –The container registry credentials object.
-
-
-```python
-def get_container_registry_credentials(self) -> ContainerRegistryCredentials:
- """
- Retrieves container registry credentials for Docker image access.
-
- Returns:
- The container registry credentials object.
- """
- response = self.request("POST", "/platform/registry-token")
- return ContainerRegistryCredentials(**response.json())
-```
-
-
### get\_device\_codes
@@ -577,13 +546,44 @@ def get_github_access_token(self, repos: list[str]) -> GithubTokenResponse:
```
+
+
+### get\_platform\_registry\_credentials
+
+```python
+get_platform_registry_credentials() -> (
+ ContainerRegistryCredentials
+)
+```
+
+Retrieves container registry credentials for Docker image access.
+
+**Returns:**
+
+* `ContainerRegistryCredentials`
+ –The container registry credentials object.
+
+
+```python
+def get_platform_registry_credentials(self) -> ContainerRegistryCredentials:
+ """
+ Retrieves container registry credentials for Docker image access.
+
+ Returns:
+ The container registry credentials object.
+ """
+ response = self.request("POST", "/platform/registry-token")
+ return ContainerRegistryCredentials(**response.json())
+```
+
+
### get\_platform\_releases
```python
get_platform_releases(
- tag: str, services: list[str], cli_version: str | None
+ tag: str, services: list[str]
) -> RegistryImageDetails
```
@@ -596,35 +596,17 @@ Resolves the platform releases for the current project.
```python
-def get_platform_releases(
- self, tag: str, services: list[str], cli_version: str | None
-) -> RegistryImageDetails:
+def get_platform_releases(self, tag: str, services: list[str]) -> RegistryImageDetails:
"""
Resolves the platform releases for the current project.
Returns:
The resolved platform releases as a ResolveReleasesResponse object.
"""
- payload = {
- "tag": tag,
- "services": services,
- "cli_version": cli_version,
- }
- try:
- response = self.request("POST", "/platform/get-releases", json_data=payload)
-
- except RuntimeError as e:
- if "403" in str(e):
- raise RuntimeError("You do not have access to platform releases.") from e
-
- if "404" in str(e):
- if "Image not found" in str(e):
- raise RuntimeError("Image not found") from e
+ from dreadnode.version import VERSION
- raise RuntimeError(
- f"Failed to get platform releases: {e}. The feature is likely disabled on this server"
- ) from e
- raise
+ payload = {"tag": tag, "services": services, "cli_version": VERSION}
+ response = self.request("POST", "/platform/get-releases", json_data=payload)
return RegistryImageDetails(**response.json())
```
diff --git a/docs/sdk/main.mdx b/docs/sdk/main.mdx
index 40f6c8c6..be900b5e 100644
--- a/docs/sdk/main.mdx
+++ b/docs/sdk/main.mdx
@@ -329,11 +329,13 @@ def configure(
# Log config information for clarity
if self.server or self.token or self.local_dir:
destination = self.server or DEFAULT_SERVER_URL or "local storage"
- rich.print(f"Dreadnode logging to [orange_red1]{destination}[/] ({config_source})")
+ logging_console.print(
+ f"Dreadnode logging to [orange_red1]{destination}[/] ({config_source})"
+ )
# Warn the user if the profile didn't resolve
elif active_profile and not (self.server or self.token):
- rich.print(
+ logging_console.print(
f":exclamation: Dreadnode profile [orange_red1]{active_profile}[/] appears invalid."
)
diff --git a/dreadnode/api/client.py b/dreadnode/api/client.py
index ad10de48..181423ac 100644
--- a/dreadnode/api/client.py
+++ b/dreadnode/api/client.py
@@ -726,7 +726,7 @@ def get_user_data_credentials(self) -> UserDataCredentials:
# Container registry access
- def get_container_registry_credentials(self) -> ContainerRegistryCredentials:
+ def get_platform_registry_credentials(self) -> ContainerRegistryCredentials:
"""
Retrieves container registry credentials for Docker image access.
@@ -736,35 +736,17 @@ def get_container_registry_credentials(self) -> ContainerRegistryCredentials:
response = self.request("POST", "/platform/registry-token")
return ContainerRegistryCredentials(**response.json())
- def get_platform_releases(
- self, tag: str, services: list[str], cli_version: str | None
- ) -> RegistryImageDetails:
+ def get_platform_releases(self, tag: str, services: list[str]) -> RegistryImageDetails:
"""
Resolves the platform releases for the current project.
Returns:
The resolved platform releases as a ResolveReleasesResponse object.
"""
- payload = {
- "tag": tag,
- "services": services,
- "cli_version": cli_version,
- }
- try:
- response = self.request("POST", "/platform/get-releases", json_data=payload)
-
- except RuntimeError as e:
- if "403" in str(e):
- raise RuntimeError("You do not have access to platform releases.") from e
-
- if "404" in str(e):
- if "Image not found" in str(e):
- raise RuntimeError("Image not found") from e
+ from dreadnode.version import VERSION
- raise RuntimeError(
- f"Failed to get platform releases: {e}. The feature is likely disabled on this server"
- ) from e
- raise
+ payload = {"tag": tag, "services": services, "cli_version": VERSION}
+ response = self.request("POST", "/platform/get-releases", json_data=payload)
return RegistryImageDetails(**response.json())
def get_platform_templates(self, tag: str) -> bytes:
diff --git a/dreadnode/cli/docker.py b/dreadnode/cli/docker.py
new file mode 100644
index 00000000..83d06c9f
--- /dev/null
+++ b/dreadnode/cli/docker.py
@@ -0,0 +1,387 @@
+import re
+import subprocess # nosec
+import typing as t
+
+from pydantic import AliasChoices, BaseModel, Field
+
+from dreadnode.common_types import UNSET, Unset
+from dreadnode.constants import (
+ DEFAULT_DOCKER_REGISTRY_IMAGE_TAG,
+ DEFAULT_DOCKER_REGISTRY_LOCAL_PORT,
+ DEFAULT_DOCKER_REGISTRY_SUBDOMAIN,
+ DEFAULT_PLATFORM_BASE_DOMAIN,
+)
+from dreadnode.logging_ import print_info
+from dreadnode.user_config import ServerConfig
+
+DockerContainerState = t.Literal[
+ "running", "exited", "paused", "restarting", "removing", "created", "dead"
+]
+
+
+class DockerError(Exception):
+ pass
+
+
+class DockerImage(str): # noqa: SLOT000
+ """
+ A string subclass that normalizes and parses various Docker image string formats.
+
+ Supported formats:
+ - ubuntu
+ - ubuntu:22.04
+ - library/ubuntu:22.04
+ - docker.io/library/ubuntu:22.04
+ - myregistry:5000/my/image:latest
+ - myregistry:5000/my/image@sha256:f6e42a...
+ - dreadnode/image (correctly parsed as a Docker Hub image)
+ """
+
+ registry: str | None
+ repository: str
+ tag: str | None
+ digest: str | None
+
+ def __new__(cls, value: str, *_: t.Any, **__: t.Any) -> "DockerImage":
+ value = value.strip()
+ if not value:
+ raise ValueError("Invalid Docker image format: input cannot be empty")
+
+ # 1. Separate digest from the rest
+ digest: str | None = None
+ if "@" in value:
+ value, digest = value.split("@", 1)
+
+ # 2. Separate tag from the repository path
+ tag: str | None = None
+ repo_path = value
+ # A tag is present if there's a colon that is NOT part of a port number in a registry hostname
+ if ":" in value:
+ possible_repo, possible_tag = value.rsplit(":", 1)
+ # If the part before the colon contains a slash, or no slash at all, it's a tag.
+ # This correctly handles "ubuntu:22.04" and "gcr.io/my/image:tag" but not "localhost:5000/image".
+ if "/" in possible_repo or "/" not in value:
+ repo_path, tag = possible_repo, possible_tag
+
+ if not repo_path:
+ raise ValueError("Invalid Docker image format: missing repository name")
+
+ # 3. Determine the registry and the final repository name
+ registry: str | None = None
+ repository = repo_path
+
+ if "/" not in repo_path:
+ # Case 1: An official image like "ubuntu".
+ repository = f"library/{repo_path}"
+ else:
+ # Case 2: A namespaced path. It could be "dreadnode/image"
+ # or "gcr.io/google-containers/busybox".
+ first_part = repo_path.split("/", 1)[0]
+ if "." in first_part or ":" in first_part:
+ # If the first part has a "." or ":", it's a registry hostname.
+ registry = first_part
+ repository = repo_path.split("/", 1)[1]
+
+ # 4. Default to 'latest' tag if no tag or digest is provided
+ if not tag and not digest:
+ tag = "latest"
+
+ # 5. Construct the full, normalized string for the object's value
+ full_image_str = repository
+ if registry:
+ full_image_str = f"{registry}/{repository}"
+
+ if tag:
+ full_image_str += f":{tag}"
+ if digest:
+ full_image_str += f"@{digest}"
+
+ obj = super().__new__(cls, full_image_str)
+ obj.registry = registry
+ obj.repository = repository
+ obj.tag = tag
+ obj.digest = digest
+
+ return obj
+
+ def __repr__(self) -> str:
+ parts = [f"repository='{self.repository}'"]
+ if self.registry:
+ parts.append(f"registry='{self.registry}'")
+ if self.tag:
+ parts.append(f"tag='{self.tag}'")
+ if self.digest:
+ parts.append(f"digest='{self.digest}'")
+ return f"{self.__class__.__name__}({', '.join(parts)})"
+
+ def with_(
+ self,
+ *,
+ repository: str | Unset = UNSET,
+ registry: str | None | Unset = UNSET,
+ tag: str | None | Unset = UNSET,
+ digest: str | None | Unset = UNSET,
+ ) -> "DockerImage":
+ """
+ Create a new DockerImage instance with updated elements.
+ """
+ new_registry = registry if not isinstance(registry, Unset) else self.registry
+ new_repository = repository if not isinstance(repository, Unset) else self.repository
+ new_tag = tag if not isinstance(tag, Unset) else self.tag
+ new_digest = digest if not isinstance(digest, Unset) else self.digest
+
+ new_image = new_repository
+ if new_registry:
+ new_image = f"{new_registry}/{new_repository}"
+
+ if new_tag:
+ new_image += f":{new_tag}"
+ if new_digest:
+ new_image += f"@{new_digest}"
+
+ return DockerImage(new_image)
+
+
+class DockerContainer(BaseModel):
+ id: str = Field(..., alias="ID")
+ name: str = Field(..., validation_alias=AliasChoices("Name", "Names"))
+ exit_code: int = Field(-1, alias="ExitCode")
+ state: DockerContainerState = Field(..., alias="State")
+ status: str = Field(..., alias="Status")
+ raw_ports: str = Field(..., alias="Ports")
+ image: str = Field(..., alias="Image")
+ command: str = Field(..., alias="Command")
+
+ @property
+ def is_running(self) -> bool:
+ return self.state == "running"
+
+ @property
+ def ports(self) -> list[tuple[int, int]]:
+ """
+ Parse the raw_ports string into a list of tuples mapping host ports to container ports.
+ """
+ ports = []
+ for mapping in self.raw_ports.split(","):
+ host_part, container_part = mapping.split("->")
+ host_port = int(host_part.split(":")[-1])
+ container_port = int(container_part.split("/")[0])
+ ports.append((host_port, container_port))
+ return ports
+
+
+def docker_run(
+ args: list[str],
+ *,
+ timeout: int = 300,
+ stdin_input: str | None = None,
+ capture_output: bool = False,
+) -> subprocess.CompletedProcess[str]:
+ """
+ Execute a docker command with common error handling and configuration.
+
+ Args:
+ args: Additional arguments for the docker command.
+ timeout: Command timeout in seconds.
+ stdin_input: Optional input string to pass to the command's stdin.
+ capture_output: Whether to capture the command's output.
+
+ Returns:
+ CompletedProcess object with command results.
+
+ Raises:
+ subprocess.CalledProcessError: If command fails.
+ subprocess.TimeoutExpired: If command times out.
+ FileNotFoundError: If docker/docker-compose not found.
+ """
+ try:
+ result = subprocess.run( # noqa: S603 # nosec
+ ["docker", *args], # noqa: S607
+ check=True,
+ text=True,
+ timeout=timeout,
+ encoding="utf-8",
+ errors="replace",
+ input=stdin_input,
+ capture_output=capture_output,
+ )
+ if not capture_output:
+ print_info("") # Some padding after command output
+
+ except subprocess.CalledProcessError as e:
+ command_str = " ".join(e.cmd)
+ error = f"Docker command failed: {command_str}"
+ if e.stderr:
+ error += f"\n\n{e.stderr}"
+ raise DockerError(error) from e
+
+ except subprocess.TimeoutExpired as e:
+ command_str = " ".join(e.cmd)
+ raise DockerError(f"Docker command timed out after {timeout} seconds: {command_str}") from e
+
+ except FileNotFoundError as e:
+ raise DockerError(
+ "`docker` not found, please ensure it is installed and in your PATH."
+ ) from e
+
+ return result
+
+
+def get_available_local_images() -> list[DockerImage]:
+ """
+ Get the list of available Docker images on the local system.
+
+ Returns:
+ List of available Docker image names.
+ """
+ result = docker_run(
+ ["images", "--format", "{{.Repository}}:{{.Tag}}@{{.Digest}}"],
+ capture_output=True,
+ timeout=30,
+ )
+ return [DockerImage(line) for line in result.stdout.splitlines() if line.strip()]
+
+
+def get_env_var_from_container(container_name: str, var_name: str) -> str | None:
+ """
+ Get the specified environment variable from the container and return
+ its value.
+
+ Args:
+ container_name: Name of the container to inspect.
+ var_name: Name of the environment variable to retrieve.
+
+ Returns:
+ str | None: Value of the environment variable, or None if not found.
+ """
+ result = docker_run(
+ ["inspect", "-f", "{{range .Config.Env}}{{println .}}{{end}}", container_name],
+ capture_output=True,
+ timeout=30,
+ )
+ for line in result.stdout.splitlines():
+ if line.startswith(f"{var_name.upper()}="):
+ return line.split("=", 1)[1]
+ return None
+
+
+def docker_login(registry: str, username: str, password: str) -> None:
+ """
+ Log into a Docker registry.
+
+ Args:
+ registry: Registry hostname to log into.
+ username: Username for the registry.
+ password: Password for the registry.
+ """
+ docker_run(
+ ["login", registry, "--username", username, "--password-stdin"],
+ stdin_input=password,
+ capture_output=True,
+ timeout=60,
+ )
+
+
+def docker_ps() -> list[DockerContainer]:
+ """
+ List and parse running containers using `docker ps`.
+
+ Returns:
+ A list of DockerContainer objects.
+ """
+ result = docker_run(
+ ["ps", "--format", "json"],
+ capture_output=True,
+ )
+ return [
+ DockerContainer.model_validate_json(line)
+ for line in result.stdout.splitlines()
+ if line.strip()
+ ]
+
+
+def docker_compose_ps(args: list[str] | None = None) -> list[DockerContainer]:
+ """
+ List and parse running containers using `docker compose ps`.
+
+ This mirrors:
+ docker compose [*args] ps --format json
+
+ Args:
+ args: Additional docker compose arguments.
+
+ Returns:
+ A list of DockerContainer objects.
+ """
+ result = docker_run(
+ ["compose", *(args or []), "ps", "--format", "json"],
+ capture_output=True,
+ )
+ return [
+ DockerContainer.model_validate_json(line)
+ for line in result.stdout.splitlines()
+ if line.strip()
+ ]
+
+
+def docker_tag(image: str | DockerImage, new_tag: str) -> None:
+ """
+ Tag a Docker image with a new tag.
+
+ Args:
+ image: The name of the image to tag.
+ new_tag: The new tag to apply to the image.
+ """
+ docker_run(
+ ["tag", str(image), new_tag],
+ capture_output=True,
+ timeout=60,
+ )
+
+
+def get_local_registry_port() -> int:
+ for container in docker_ps():
+ if DEFAULT_DOCKER_REGISTRY_IMAGE_TAG in container.image and container.ports:
+ # return the first mapped port
+ return container.ports[0][0]
+
+ # fallback to the default port
+ return DEFAULT_DOCKER_REGISTRY_LOCAL_PORT
+
+
+def get_registry(config: ServerConfig) -> str:
+ # localhost is a special case
+ if "localhost" in config.url or "127.0.0.1" in config.url:
+ return f"localhost:{get_local_registry_port()}"
+
+ prefix = ""
+ if "staging-" in config.url:
+ prefix = "staging-"
+ elif "dev-" in config.url:
+ prefix = "dev-"
+
+ return f"{prefix}{DEFAULT_DOCKER_REGISTRY_SUBDOMAIN}.{DEFAULT_PLATFORM_BASE_DOMAIN}"
+
+
+def docker_pull(image: str | DockerImage) -> None:
+ """
+ Pull a Docker image.
+
+ Args:
+ image: The name of the image to pull.
+ """
+ docker_run(["pull", image])
+
+
+def clean_username(name: str) -> str:
+ """
+ Sanitizes an agent or user name to be used in a Docker repository URI.
+ """
+ # convert to lowercase
+ name = name.lower()
+ # remove characters that are not word characters, whitespace, or hyphens
+ name = re.sub(r"[^\w\s-]", "", name)
+ # replace one or more whitespace characters with a single hyphen
+ name = re.sub(r"[-\s]+", "-", name)
+ # remove leading or trailing hyphens
+ return name.strip("-")
diff --git a/dreadnode/cli/github.py b/dreadnode/cli/github.py
index 3ca6ea55..0731d03c 100644
--- a/dreadnode/cli/github.py
+++ b/dreadnode/cli/github.py
@@ -6,9 +6,9 @@
import zipfile
import httpx
-import rich
from rich.prompt import Prompt
+from dreadnode.logging_ import confirm, console, print_info, print_warning
from dreadnode.user_config import UserConfig, find_dreadnode_saas_profiles, is_dreadnode_saas_server
@@ -174,7 +174,7 @@ def download_and_unzip_archive(url: str, *, headers: dict[str, str] | None = Non
temp_dir = pathlib.Path(tempfile.mkdtemp())
local_zip_path = temp_dir / "archive.zip"
- rich.print(f":arrow_double_down: Downloading {url} ...")
+ print_info(f"Downloading {url} ...")
# download to temporary file
with httpx.stream("GET", url, follow_redirects=True, verify=True, headers=headers) as response:
@@ -216,58 +216,50 @@ def validate_server_for_clone(user_config: UserConfig, current_profile: str | No
return current_profile or user_config.active_profile_name
# Current server is not a Dreadnode SaaS server - warn user
- rich.print()
- rich.print(":warning: [yellow]Warning: Current server is not a Dreadnode SaaS server[/]")
- rich.print(f" Current server: [cyan]{current_server}[/]")
- rich.print(f" Current profile: [cyan]{current_profile or user_config.active_profile_name}[/]")
- rich.print()
- rich.print("Git clone for private dreadnode repositories requires a Dreadnode SaaS server")
- rich.print("(ending with '.dreadnode.io') for authentication to work properly.")
- rich.print()
+ print_warning(
+ f"Current server is not a Dreadnode SaaS server\n"
+ f" Current server: [cyan]{current_server}[/]\n"
+ f" Current profile: [cyan]{current_profile or user_config.active_profile_name}[/]\n\n"
+ "Git clone for private dreadnode repositories requires a Dreadnode SaaS server\n"
+ "(ending with '.dreadnode.io') for authentication to work properly.\n"
+ )
# Check if there are any SaaS profiles available
saas_profiles = find_dreadnode_saas_profiles(user_config)
if saas_profiles:
- rich.print("Available Dreadnode SaaS profiles:")
+ print_info("Available Dreadnode SaaS profiles:")
for profile in saas_profiles:
server_url = user_config.servers[profile].url
- rich.print(f" - [green]{profile}[/] ({server_url})")
- rich.print()
+ console.print(f" - [bold]{profile}[/] ({server_url})")
choices = ["continue", "switch", "cancel"]
choice = Prompt.ask(
- "Choose an option", choices=choices, default="cancel", show_choices=True
+ "\nChoose an option", choices=choices, default="cancel", show_choices=True
)
if choice == "continue":
- rich.print(
- ":warning: [yellow]Continuing with current server - private repository access may fail[/]"
- )
+ print_warning("Continuing with current server - private repository access may fail")
return current_profile or user_config.active_profile_name
if choice == "cancel":
- rich.print("Cancelled.")
return None
if choice == "switch":
# Let user pick a profile
profile_choice = Prompt.ask(
- "Select profile to use", choices=saas_profiles, default=saas_profiles[0]
- )
- rich.print(
- f":arrows_counterclockwise: Using profile '[green]{profile_choice}[/]' for this operation"
+ "\nSelect profile to use",
+ choices=saas_profiles,
+ default=saas_profiles[0],
+ console=console,
)
+ print_info(f"Using profile '[cyan]{profile_choice}[/]' for this operation")
return profile_choice
else:
# No SaaS profiles available
- choice = Prompt.ask("Continue anyway?", choices=["y", "n"], default="n")
-
- if choice == "y":
- rich.print(
- ":warning: [yellow]Continuing with current server - private repository access may fail[/]"
+ if not confirm("Continue anyway?"):
+ print_info(
+ "Cancelled. Use [bold]dreadnode login --server https://platform.dreadnode.io[/] to add a SaaS profile."
)
- return current_profile or user_config.active_profile_name
- rich.print(
- "Cancelled. Use [bold]dreadnode login --server https://platform.dreadnode.io[/] to add a SaaS profile."
- )
+ print_warning("Continuing with current server - private repository access may fail")
+ return current_profile or user_config.active_profile_name
return None
diff --git a/dreadnode/cli/main.py b/dreadnode/cli/main.py
index f1f98a05..b0a05d75 100644
--- a/dreadnode/cli/main.py
+++ b/dreadnode/cli/main.py
@@ -8,13 +8,12 @@
import webbrowser
import cyclopts
-import rich
from rich.panel import Panel
-from rich.prompt import Prompt
from dreadnode.api.client import ApiClient
from dreadnode.cli.agent import cli as agent_cli
from dreadnode.cli.api import create_api_client
+from dreadnode.cli.docker import DockerImage, docker_login, docker_pull, docker_tag, get_registry
from dreadnode.cli.eval import cli as eval_cli
from dreadnode.cli.github import (
GithubRepo,
@@ -25,6 +24,7 @@
from dreadnode.cli.profile import cli as profile_cli
from dreadnode.cli.study import cli as study_cli
from dreadnode.constants import DEBUG, PLATFORM_BASE_URL
+from dreadnode.logging_ import confirm, console, print_info, print_success
from dreadnode.user_config import ServerConfig, UserConfig
cli = cyclopts.App(help="Interact with Dreadnode platforms", version_flags=[], help_on_error=True)
@@ -43,30 +43,29 @@ def meta(
*tokens: t.Annotated[str, cyclopts.Parameter(show=False, allow_leading_hyphen=True)],
) -> None:
try:
- rich.print()
+ console.print()
cli(tokens)
except Exception as e:
if DEBUG:
raise
-
- rich.print()
- rich.print(Panel(str(e), title="Error", title_align="left", border_style="red"))
+ console.print()
+ console.print(Panel(str(e), title="Error", title_align="left", border_style="red"))
sys.exit(1)
@cli.command(group="Auth")
def login(
*,
- server: t.Annotated[
- str | None,
- cyclopts.Parameter(name=["--server", "-s"], help="URL of the server"),
- ] = None,
- profile: t.Annotated[
- str | None,
- cyclopts.Parameter(name=["--profile", "-p"], help="Profile alias to assign / update"),
- ] = None,
+ server: t.Annotated[str | None, cyclopts.Parameter(name=["--server", "-s"])] = None,
+ profile: t.Annotated[str | None, cyclopts.Parameter(name=["--profile", "-p"])] = None,
) -> None:
- """Authenticate to a Dreadnode platform server and save the profile."""
+ """
+ Authenticate to a Dreadnode platform server and save the profile.
+
+ Args:
+ server: The server URL to authenticate against.
+ profile: The profile name to save the server configuration under.
+ """
if not server:
server = PLATFORM_BASE_URL
with contextlib.suppress(Exception):
@@ -76,7 +75,7 @@ def login(
# create client with no auth data
client = ApiClient(base_url=server)
- rich.print(":laptop_computer: Requesting device code ...")
+ print_info("Requesting device code ...")
# request user and device codes
codes = client.get_device_codes()
@@ -85,16 +84,15 @@ def login(
verification_url = client.url_for_user_code(codes.user_code)
verification_url_base = verification_url.split("?")[0]
- rich.print()
- rich.print(
- f"""\
-Attempting to automatically open the authorization page in your default browser.
-If the browser does not open or you wish to use a different device, open the following URL:
+ print_info(
+ f"""
+ Attempting to automatically open the authorization page in your default browser.
+ If the browser does not open or you wish to use a different device, open the following URL:
-:link: [bold]{verification_url_base}[/]
+ [bold]{verification_url_base}[/]
-Then enter the code: [bold]{codes.user_code}[/]
-"""
+ Then enter the code: [bold]{codes.user_code}[/]
+ """
)
webbrowser.open(verification_url)
@@ -111,7 +109,8 @@ def login(
)
user = client.get_user()
- UserConfig.read().set_server_config(
+ user_config = UserConfig.read()
+ user_config.set_server_config(
ServerConfig(
url=server,
access_token=tokens.access_token,
@@ -121,9 +120,11 @@ def login(
api_key=user.api_key.key,
),
profile,
- ).write()
+ )
+ user_config.active = profile
+ user_config.write()
- rich.print(f":white_check_mark: Authenticated as {user.email_address} ({user.username})")
+ print_success(f"Authenticated as {user.email_address} ({user.username})")
@cli.command(group="Auth")
@@ -142,32 +143,64 @@ def refresh() -> None:
user_config.set_server_config(server_config).write()
- rich.print(
- f":white_check_mark: Refreshed '[bold]{user_config.active}[/bold]' ([magenta]{user.email_address}[/] / [cyan]{user.username}[/])"
+ print_success(
+ f"Refreshed '[bold]{user_config.active}[/bold]' ([magenta]{user.email_address}[/] / [cyan]{user.username}[/])"
)
+@cli.command()
+def pull(image: str) -> None:
+ """
+ Pull a capability image from the dreadnode registry.
+
+ Args:
+ image: The name of the image to pull (e.g. dreadnode/agent:latest).
+ """
+ user_config = UserConfig.read()
+ if not user_config.active_profile_name:
+ raise RuntimeError("No server profile is set, use [bold]dreadnode login[/] to authenticate")
+
+ server_config = user_config.get_server_config()
+
+ docker_image = DockerImage(image)
+ tag_as: str | None = None
+ if docker_image.repository.startswith("dreadnode/") and not docker_image.registry:
+ docker_image = docker_image.with_(
+ registry=get_registry(user_config.get_server_config()),
+ )
+ tag_as = image
+
+ if docker_image.registry and docker_image.registry != "docker.io":
+ print_info(f"Authenticating to [bold]{docker_image.registry}[/] ...")
+ docker_login(docker_image.registry, server_config.username, server_config.api_key)
+
+ print_info(f"Pulling image [bold]{docker_image}[/] ...")
+ docker_pull(docker_image)
+
+ if tag_as:
+ docker_tag(docker_image, tag_as)
+
+
@cli.command()
def clone(
- repo: t.Annotated[str, cyclopts.Parameter(help="Repository name or URL")],
- target: t.Annotated[
- pathlib.Path | None,
- cyclopts.Parameter(help="The target directory"),
- ] = None,
+ repo: str,
+ target: pathlib.Path | None = None,
) -> None:
- """Clone a GitHub repository to a local directory"""
+ """
+ Clone a GitHub repository to a local directory
+ Args:
+ repo: Repository name or URL.
+ target: The target directory.
+ """
github_repo = GithubRepo(repo)
# Check if the target directory exists
target = target or pathlib.Path(github_repo.repo)
if target.exists():
- if (
- Prompt.ask(f":axe: Overwrite {target.absolute()}?", choices=["y", "n"], default="n")
- == "n"
- ):
+ if not confirm(f"{target.absolute()} exists, overwrite?"):
return
- rich.print()
+ console.print()
shutil.rmtree(target)
# Check if the repo is accessible
@@ -187,7 +220,7 @@ def clone(
github_access_token = create_api_client(profile=profile_to_use).get_github_access_token(
[github_repo.repo]
)
- rich.print(":key: Accessed private repository")
+ print_info("Accessed private repository")
temp_dir = download_and_unzip_archive(
github_repo.api_zip_url,
headers={"Authorization": f"Bearer {github_access_token.token}"},
@@ -204,8 +237,7 @@ def clone(
shutil.move(temp_dir, target)
- rich.print()
- rich.print(f":tada: Cloned [b]{repo}[/] to [b]{target.absolute()}[/]")
+ print_success(f"Cloned [b]{repo}[/] to [b]{target.absolute()}[/]")
@cli.command(help="Show versions and exit.", group="Meta")
@@ -215,6 +247,6 @@ def version() -> None:
os_name = platform.system()
arch = platform.machine()
- rich.print(f"Platform: {os_name} ({arch})")
- rich.print(f"Python: {python_version}")
- rich.print(f"Dreadnode: {version}")
+ print_info(f"Platform: {os_name} ({arch})")
+ print_info(f"Python: {python_version}")
+ print_info(f"Dreadnode: {version}")
diff --git a/dreadnode/cli/platform/cli.py b/dreadnode/cli/platform/cli.py
index 89879b2f..2e261d7c 100644
--- a/dreadnode/cli/platform/cli.py
+++ b/dreadnode/cli/platform/cli.py
@@ -2,100 +2,199 @@
import cyclopts
-from dreadnode.cli.platform.configure import configure_platform, list_configurations
+from dreadnode.cli.docker import (
+ DockerError,
+ get_env_var_from_container,
+)
+from dreadnode.cli.platform.compose import (
+ build_compose_override_file,
+ compose_down,
+ compose_login,
+ compose_logs,
+ compose_up,
+ platform_is_running,
+)
+from dreadnode.cli.platform.constants import PLATFORM_SERVICES
from dreadnode.cli.platform.download import download_platform
-from dreadnode.cli.platform.login import log_into_registries
-from dreadnode.cli.platform.start import start_platform
-from dreadnode.cli.platform.status import platform_status
-from dreadnode.cli.platform.stop import stop_platform
-from dreadnode.cli.platform.upgrade import upgrade_platform
-from dreadnode.cli.platform.utils.printing import print_info
-from dreadnode.cli.platform.utils.versions import get_current_version
+from dreadnode.cli.platform.env_mgmt import (
+ build_env_file,
+ read_env_file,
+ remove_overrides_env,
+ write_overrides_env,
+)
+from dreadnode.cli.platform.tag import tag_to_semver
+from dreadnode.cli.platform.version import VersionConfig
+from dreadnode.logging_ import confirm, print_error, print_info, print_success, print_warning
cli = cyclopts.App("platform", help="Run and manage the platform.", help_flags=[])
@cli.command()
-def start(
- tag: t.Annotated[
- str | None, cyclopts.Parameter(help="Optional image tag to use when starting the platform.")
- ] = None,
- **env_overrides: t.Annotated[
- str,
- cyclopts.Parameter(
- help="Environment variable overrides. Use --key value format. "
- "Examples: --proxy-host myproxy.local"
- ),
- ],
-) -> None:
- """Start the platform. Optionally, provide a tagged version to start.
+def start(tag: str | None = None, **env_overrides: str) -> None:
+ """
+ Start the platform.
Args:
- tag: Optional image tag to use when starting the platform.
- **env_overrides: Key-value pairs to override environment variables in the
+ tag: Image tag to use when starting the platform.
+ env_overrides: Key-value pairs to override environment variables in the
platform's .env file. e.g `--proxy-host myproxy.local`
"""
- start_platform(tag=tag, **env_overrides)
+ version_config = VersionConfig.read()
+ version = version_config.get_current_version(tag=tag) or download_platform(tag)
+ version_config.set_current_version(version)
+
+ if platform_is_running(version):
+ print_info(f"Platform {version.tag} is already running.")
+ print_info("Use `dreadnode platform stop` to stop it first.")
+ return
+
+ compose_login(version)
+
+ if env_overrides:
+ write_overrides_env(version.arg_overrides_env_file, **env_overrides)
+
+ print_info(f"Starting platform [cyan]{version.tag}[/] ...")
+ try:
+ compose_up(version)
+ print_success("Platform started.")
+ origin = get_env_var_from_container("dreadnode-ui", "ORIGIN")
+ if origin:
+ print_info("You can access the app at the following URLs:")
+ print_info(f" - {origin}")
+ else:
+ print_info(" - Unable to determine the app URL.")
+ print_info("Please check the container logs for more information.")
+ except DockerError as e:
+ compose_logs(version, tail=10)
+ print_error(str(e))
@cli.command(name=["stop", "down"])
-def stop() -> None:
- """Stop the running platform."""
- stop_platform()
+def stop(*, remove_volumes: t.Annotated[bool, cyclopts.Parameter(negative=False)] = False) -> None:
+ """
+ Stop the running platform.
+
+ Args:
+ remove_volumes: Also remove Docker volumes associated with the platform.
+ """
+ version = VersionConfig.read().get_current_version()
+ if not version:
+ print_error("No current version found. Nothing to stop.")
+ return
+
+ remove_overrides_env(version.arg_overrides_env_file)
+ compose_down(version, remove_volumes=remove_volumes)
+ print_success("Platform stopped.")
@cli.command()
-def download(
- tag: t.Annotated[
- str | None, cyclopts.Parameter(help="Optional image tag to use when starting the platform.")
- ] = None,
-) -> None:
- """Download platform files for a specific tag.
+def logs(tail: int = 100) -> None:
+ """
+ View the platform logs.
Args:
- tag: Optional image tag to download.
+ tail: Number of lines to show from the end of the logs for each service.
+ """
+ version = VersionConfig.read().get_current_version()
+ if not version:
+ print_error("No current version found. Nothing to show logs for.")
+ return
+
+ compose_logs(version, tail=tail)
+
+
+@cli.command()
+def download(tag: str | None = None) -> None:
+ """
+ Download platform files for a specific tag.
+
+ Args:
+ tag: Specific version tag to download.
"""
download_platform(tag=tag)
@cli.command()
def upgrade() -> None:
- """Upgrade the platform to the latest version."""
- upgrade_platform()
+ """
+ Upgrade the platform to the latest available version.
+
+ Downloads the latest version, compares it with the current version,
+ and performs the upgrade if a newer version is available. Optionally
+ merges configuration files from the current version to the new version.
+ Stops the current platform and starts the upgraded version.
+ """
+ version_config = VersionConfig.read()
+ current_version = version_config.get_current_version()
+ if not current_version:
+ start()
+ return
+
+ latest_version = download_platform()
+
+ current_semver = tag_to_semver(current_version.tag)
+ remote_semver = tag_to_semver(latest_version.tag)
+
+ if current_semver >= remote_semver:
+ print_info(f"You are using the latest ({current_semver}) version of the platform.")
+ return
+
+ if not confirm(
+ f"Upgrade from [cyan]{current_version.tag}[/] -> [magenta]{latest_version.tag}[/]?"
+ ):
+ return
+
+ version_config.set_current_version(latest_version)
+
+ # copy the configuration overrides from the current version to the new version
+ if (
+ current_version.configure_overrides_compose_file.exists()
+ and current_version.configure_overrides_env_file.exists()
+ ):
+ latest_version.configure_overrides_compose_file.write_text(
+ current_version.configure_overrides_compose_file.read_text()
+ )
+ latest_version.configure_overrides_env_file.write_text(
+ current_version.configure_overrides_env_file.read_text()
+ )
+
+ print_info("Stopping current platform ...")
+ compose_down(current_version)
+ compose_up(latest_version)
+ print_success(f"Platform upgraded to version [magenta]{latest_version.tag}[/].")
@cli.command()
def refresh_registry_auth() -> None:
- """Refresh container registry credentials for platform access.
+ """
+ Refresh container registry credentials for platform access.
Used for out of band Docker management.
"""
- log_into_registries()
+ current_version = VersionConfig.read().get_current_version()
+ if not current_version:
+ print_info("There are no registries configured. Run `dreadnode platform start` to start.")
+ return
+
+ compose_login(current_version, force=True)
@cli.command()
def configure(
- *args: t.Annotated[
- str,
- cyclopts.Parameter(
- help="Key-value pairs to set. Must be provided in pairs (key value key value ...). ",
- ),
- ],
- tag: t.Annotated[
- str | None, cyclopts.Parameter(help="Optional image tag to use when starting the platform.")
- ] = None,
+ *args: str,
+ tag: str | None = None,
list: t.Annotated[
bool,
- cyclopts.Parameter(
- ["--list", "-l"], help="List current configuration without making changes."
- ),
+ cyclopts.Parameter(["--list", "-l"], negative=False),
] = False,
unset: t.Annotated[
bool,
- cyclopts.Parameter(["--unset", "-u"], help="Remove the specified configuration."),
+ cyclopts.Parameter(["--unset", "-u"], negative=False),
] = False,
) -> None:
- """Configure the platform for a specific service.
+ """
+ Configure the platform for a specific service.
+
Configurations will take effect the next time the platform is started and are persisted.
Usage: platform configure KEY VALUE [KEY2 VALUE2 ...]
@@ -104,14 +203,28 @@ def configure(
platform configure proxy-host myproxy.local api-port 8080
Args:
- *args: Key-value pairs to set. Must be provided in pairs (key value key value ...).
+ args: Key-value pairs to set. Must be provided in pairs (key value key value ...).
tag: Optional image tag to use when starting the platform.
+ list: List current configuration without making changes.
+ unset: Remove the specified configuration.
"""
+ current_version = VersionConfig.read().get_current_version(tag=tag)
+ if not current_version:
+ print_info("No current platform version is set. Please start or download the platform.")
+ return
+
if list:
- if args:
- raise ValueError("The --list option does not take any positional arguments.")
- list_configurations()
+ overrides_env_file = current_version.configure_overrides_env_file
+ if not overrides_env_file.exists():
+ print_info("No configuration overrides found.")
+ return
+
+ print_info(f"Configuration overrides from {overrides_env_file}:")
+ env_vars = read_env_file(overrides_env_file)
+ for key, value in env_vars.items():
+ print_info(f" - {key}={value}")
return
+
# Parse positional arguments into key-value pairs
if not unset and len(args) % 2 != 0:
raise ValueError(
@@ -122,43 +235,57 @@ def configure(
env_overrides = {}
for i in range(0, len(args), 2):
key = args[i]
- value = args[i + 1] if not unset else None
- env_overrides[key] = value
+ env_overrides[key] = args[i + 1] if not unset else None
- configure_platform(tag=tag, **env_overrides)
+ if not env_overrides:
+ print_warning("No configuration changes specified.")
+ return
+
+ print_info("Setting environment overrides ...")
+ build_compose_override_file(PLATFORM_SERVICES, current_version)
+ build_env_file(current_version.configure_overrides_env_file, **env_overrides)
+ print_info(
+ f"Configuration written to {current_version.local_path}.\n\n"
+ "These will take effect the next time the platform is started. "
+ "You can modify or remove them at any time."
+ )
@cli.command()
def version(
verbose: t.Annotated[ # noqa: FBT002
- bool,
- cyclopts.Parameter(
- ["--verbose", "-v"], help="Display detailed information for the version."
- ),
+ bool, cyclopts.Parameter(["--verbose", "-v"])
] = False,
) -> None:
- """Show the current platform version."""
- version = get_current_version()
- if version:
- if verbose:
- print_info(version.details)
- else:
- print_info(f"Current platform version: {version!s}")
+ """
+ Show the current platform version.
- else:
+ Args:
+ verbose: Display detailed information about the version.
+ """
+ version_config = VersionConfig.read()
+ version = version_config.get_current_version()
+ if version is None:
print_info("No current platform version is set.")
+ return
+ print_info(f"Current version: [cyan]{version}[/]")
+ if verbose:
+ print_info(version.details)
-@cli.command()
-def status(
- tag: t.Annotated[
- str | None, cyclopts.Parameter(help="Optional image tag to use when checking status.")
- ] = None,
-) -> None:
- """Get the status of the platform with the specified or current version.
- Args:
- tag: Optional image tag to use. If not provided, uses the current
- version or downloads the latest available version.
+@cli.command()
+def status() -> None:
"""
- platform_status(tag=tag)
+ Get the current status of the platform.
+ """
+ version_config = VersionConfig.read()
+ version = version_config.get_current_version()
+ if version is None:
+ print_error("No current platform version is set. Please start or download the platform.")
+ return
+
+ if platform_is_running(version):
+ print_success(f"Platform {version.tag} is running.")
+ else:
+ print_error(f"Platform {version.tag} is not fully running.")
diff --git a/dreadnode/cli/platform/compose.py b/dreadnode/cli/platform/compose.py
new file mode 100644
index 00000000..f0ebd992
--- /dev/null
+++ b/dreadnode/cli/platform/compose.py
@@ -0,0 +1,212 @@
+import typing as t
+
+import yaml
+from yaml import safe_dump
+
+from dreadnode.cli.api import create_api_client
+from dreadnode.cli.docker import (
+ DockerImage,
+ docker_compose_ps,
+ docker_login,
+ docker_run,
+ get_available_local_images,
+)
+from dreadnode.cli.platform.constants import PlatformService
+from dreadnode.cli.platform.env_mgmt import read_env_file
+from dreadnode.cli.platform.version import LocalVersion
+from dreadnode.logging_ import print_info
+
+
+def build_compose_override_file(
+ services: list[PlatformService],
+ version: LocalVersion,
+) -> None:
+ # build a yaml docker compose override file
+ # that only includes the service being configured
+ # and has an `env_file` attribute for the service
+ override = {
+ "services": {
+ f"{service}": {"env_file": [version.configure_overrides_env_file.as_posix()]}
+ for service in services
+ },
+ }
+
+ with version.configure_overrides_compose_file.open("w") as f:
+ safe_dump(override, f, sort_keys=False)
+
+
+def get_required_images(version: LocalVersion) -> list[DockerImage]:
+ """
+ Get the list of required Docker images for the specified platform version.
+
+ Args:
+ version: The selected version of the platform.
+
+    Returns:
+        list[DockerImage]: List of required Docker images.
+ """
+ result = docker_run(
+ ["compose", *get_compose_args(version), "config", "--images"],
+ timeout=120,
+ capture_output=True,
+ )
+
+ if result.returncode != 0:
+ return []
+
+ return [DockerImage(line) for line in result.stdout.splitlines() if line.strip()]
+
+
+def get_required_services(version: LocalVersion) -> list[str]:
+ """Get the list of required services from the docker-compose file.
+
+    Returns:
+        list[str]: Container names of services marked `x-required`.
+ """
+ contents: dict[str, object] = yaml.safe_load(version.compose_file.read_text())
+ services = t.cast("dict[str, object]", contents.get("services", {}) or {})
+ return [
+ cfg.get("container_name")
+ for name, cfg in services.items()
+ if isinstance(cfg, dict) and cfg.get("x-required")
+ ]
+
+
+def get_profiles_to_enable(version: LocalVersion) -> list[str]:
+ """
+ Get the list of profiles to enable based on environment variables.
+
+    If any of a service's `x-profile-override-vars` are set in the platform's
+    override env files, that service's profiles will not be enabled.
+
+ E.g.
+
+ services:
+ myservice:
+ image: myimage:latest
+ profiles:
+ - myprofile
+ x-profile-override-vars:
+ - MY_SERVICE_HOST
+
+    If MY_SERVICE_HOST is set in the override env files, the `myprofile`
+    profile WILL be excluded from the docker compose --profile cmd.
+ """
+
+ contents: dict[str, object] = yaml.safe_load(version.compose_file.read_text())
+ services = t.cast("dict[str, object]", contents.get("services", {}) or {})
+ profiles_to_enable: set[str] = set()
+ for service in services.values():
+ if not isinstance(service, dict):
+ continue
+
+ profiles = service.get("profiles", [])
+ if not profiles or not isinstance(profiles, list):
+ continue
+
+ x_override_vars = service.get("x-profile-override-vars", [])
+ if not x_override_vars or not isinstance(x_override_vars, list):
+ profiles_to_enable.update(profiles)
+ continue
+
+ configuration_file = version.configure_overrides_env_file
+ overrides_file = version.arg_overrides_env_file
+
+ env_vars: dict[str, str] = {}
+ if configuration_file.exists():
+ env_vars.update(read_env_file(configuration_file))
+ if overrides_file.exists():
+ env_vars.update(read_env_file(overrides_file))
+
+ # check if any of the override vars are set in the env
+ if any(var in env_vars for var in x_override_vars):
+ continue # skip enabling this profile
+
+ profiles_to_enable.update(profiles)
+
+ return list(profiles_to_enable)
+
+
+def get_compose_args(
+ version: LocalVersion, *, project_name: str = "dreadnode-platform"
+) -> list[str]:
+ command = ["-p", project_name]
+ compose_files = [version.compose_file]
+ env_files = [version.api_env_file, version.ui_env_file]
+
+ if (
+ version.configure_overrides_compose_file.exists()
+ and version.configure_overrides_env_file.exists()
+ ):
+ compose_files.append(version.configure_overrides_compose_file)
+ env_files.append(version.configure_overrides_env_file)
+
+ for compose_file in compose_files:
+ command.extend(["-f", compose_file.as_posix()])
+
+ for profile in get_profiles_to_enable(version):
+ command.extend(["--profile", profile])
+
+ if version.arg_overrides_env_file.exists():
+ env_files.append(version.arg_overrides_env_file)
+
+ for env_file in env_files:
+ command.extend(["--env-file", env_file.as_posix()])
+
+ return command
+
+
+def platform_is_running(version: LocalVersion) -> bool:
+ """
+ Check if the platform with the specified or current version is running.
+
+ Args:
+        version: LocalVersion of the platform to check.
+ """
+ containers = docker_compose_ps(get_compose_args(version))
+ if not containers:
+ return False
+
+ for service in get_required_services(version):
+ if service not in [c.name for c in containers if c.state == "running"]:
+ return False
+
+ return True
+
+
+def compose_up(version: LocalVersion) -> None:
+ docker_run(
+ ["compose", *get_compose_args(version), "up", "-d"],
+ )
+
+
+def compose_down(version: LocalVersion, *, remove_volumes: bool = False) -> None:
+ args = ["compose", *get_compose_args(version), "down"]
+ if remove_volumes:
+ args.append("--volumes")
+ docker_run(args)
+
+
+def compose_logs(version: LocalVersion, *, tail: int = 100) -> None:
+ docker_run(["compose", *get_compose_args(version), "logs", "--tail", str(tail)])
+
+
+def compose_login(version: LocalVersion, *, force: bool = False) -> None:
+ # check to see if all required images are available locally
+ required_images = get_required_images(version)
+ available_images = get_available_local_images()
+ missing_images = [img for img in required_images if img not in available_images]
+ if not missing_images and not force:
+ return
+
+ client = create_api_client()
+ registry_credentials = client.get_platform_registry_credentials()
+
+ registries_attempted = set()
+ for image in version.images:
+ if image.registry not in registries_attempted:
+ print_info(f"Logging in to Docker registry: {image.registry} ...")
+ docker_login(
+ image.registry, registry_credentials.username, registry_credentials.password
+ )
+ registries_attempted.add(image.registry)
diff --git a/dreadnode/cli/platform/configure.py b/dreadnode/cli/platform/configure.py
deleted file mode 100644
index 171c9666..00000000
--- a/dreadnode/cli/platform/configure.py
+++ /dev/null
@@ -1,47 +0,0 @@
-from dreadnode.cli.platform.constants import SERVICES
-from dreadnode.cli.platform.docker_ import build_docker_compose_override_file
-from dreadnode.cli.platform.utils.env_mgmt import build_env_file, read_env_file
-from dreadnode.cli.platform.utils.printing import print_info
-from dreadnode.cli.platform.utils.versions import get_current_version, get_local_version
-
-
-def list_configurations() -> None:
- """List the current platform configuration overrides, if any."""
- current_version = get_current_version()
- if not current_version:
- print_info("No current platform version is set. Please start or download the platform.")
- return
-
- overrides_env_file = current_version.configure_overrides_env_file
- if not overrides_env_file.exists():
- print_info("No configuration overrides found.")
- return
-
- print_info(f"Configuration overrides from {overrides_env_file}:")
- env_vars = read_env_file(overrides_env_file)
- for key, value in env_vars.items():
- print_info(f" - {key}={value}")
-
-
-def configure_platform(tag: str | None = None, **env_overrides: str | None) -> None:
- """Configure the platform for a specific service.
-
- Args:
- service: The name of the service to configure.
- """
- selected_version = get_local_version(tag) if tag else get_current_version()
- # No need to mark current version on configure
-
- if not selected_version:
- print_info("No current platform version is set. Please start or download the platform.")
- return
-
- if env_overrides:
- print_info("Setting environment overrides...")
- build_docker_compose_override_file(SERVICES, selected_version)
- build_env_file(selected_version.configure_overrides_env_file, **env_overrides)
- print_info(
- f"Configuration written to {selected_version.local_path}. "
- "These will take effect the next time the platform is started."
- " You can modify or remove them at any time."
- )
diff --git a/dreadnode/cli/platform/constants.py b/dreadnode/cli/platform/constants.py
index 0b332953..5fe054d9 100644
--- a/dreadnode/cli/platform/constants.py
+++ b/dreadnode/cli/platform/constants.py
@@ -1,12 +1,16 @@
import typing as t
+from dreadnode.constants import DEFAULT_LOCAL_STORAGE_DIR
+
PlatformService = t.Literal["dreadnode-api", "dreadnode-ui"]
+PLATFORM_SERVICES = t.cast("list[PlatformService]", t.get_args(PlatformService))
API_SERVICE: PlatformService = "dreadnode-api"
UI_SERVICE: PlatformService = "dreadnode-ui"
-SERVICES: list[PlatformService] = [API_SERVICE, UI_SERVICE]
-VERSIONS_MANIFEST = "versions.json"
SupportedArchitecture = t.Literal["amd64", "arm64"]
-SUPPORTED_ARCHITECTURES: list[SupportedArchitecture] = ["amd64", "arm64"]
+SUPPORTED_ARCHITECTURES = t.cast("list[SupportedArchitecture]", t.get_args(SupportedArchitecture))
DEFAULT_DOCKER_PROJECT_NAME = "dreadnode-platform"
+
+PLATFORM_STORAGE_DIR = DEFAULT_LOCAL_STORAGE_DIR / "platform"
+VERSION_CONFIG_PATH = PLATFORM_STORAGE_DIR / "versions.json"
diff --git a/dreadnode/cli/platform/docker_.py b/dreadnode/cli/platform/docker_.py
deleted file mode 100644
index c44587d6..00000000
--- a/dreadnode/cli/platform/docker_.py
+++ /dev/null
@@ -1,546 +0,0 @@
-import json
-import subprocess
-import typing as t
-from dataclasses import dataclass
-from enum import Enum
-
-import yaml
-from pydantic import BaseModel, Field
-from yaml import safe_dump
-
-from dreadnode.cli.api import create_api_client
-from dreadnode.cli.platform.constants import DEFAULT_DOCKER_PROJECT_NAME, PlatformService
-from dreadnode.cli.platform.schemas import LocalVersionSchema
-from dreadnode.cli.platform.utils.env_mgmt import read_env_file
-from dreadnode.cli.platform.utils.printing import print_error, print_info, print_success
-
-DockerContainerState = t.Literal[
- "running", "exited", "paused", "restarting", "removing", "created", "dead"
-]
-
-
-# create a DockerError exception that I can catch
-class DockerError(Exception):
- pass
-
-
-class CaptureOutput(str, Enum):
- TRUE = "true"
- FALSE = "false"
-
-
-@dataclass
-class DockerImage:
- repository: str
- tag: str | None = None
- digest: str | None = None
-
- @classmethod
- def from_string(cls, image_string: str) -> "DockerImage":
- """
- Parse a Docker image string into repository, tag, and SHA components.
-
- Examples:
- - postgres:16 -> repository="postgres", tag="16", sha=None
- - minio/minio:latest -> repository="minio/minio", tag="latest", sha=None
- - image@sha256:abc123 -> repository="image", tag=None, sha="sha256:abc123"
- """
- # Check if there's a SHA digest (contains @)
- if "@" in image_string:
- repo_part, sha = image_string.split("@", 1)
- # Check if there's also a tag before the @
- if ":" in repo_part:
- repository, tag = repo_part.rsplit(":", 1)
- return cls(repository=repository, tag=tag, digest=sha)
- return cls(repository=repo_part, tag=None, digest=sha)
-
- # Check if there's a tag (contains :)
- if ":" in image_string:
- # Use rsplit to handle cases like registry.com:5000/image:tag
- repository, tag = image_string.rsplit(":", 1)
- return cls(repository=repository, tag=tag, digest=None)
-
- # Just repository name
- return cls(repository=image_string, tag=None, digest=None)
-
- def __str__(self) -> str:
- """Reconstruct the original image string format."""
- result = self.repository
- if self.tag:
- result += f":{self.tag}"
- if self.digest:
- result += f"@{self.digest}"
- return result
-
- def __eq__(self, other: object) -> bool:
- """Check if two DockerImage instances are equal.
-
- If they both have digests, compare digests.
- If they both have tags, compare tags.
-
- """
- if not isinstance(other, DockerImage):
- return False
- if self.repository != other.repository:
- return False
- if self.digest and other.digest:
- return self.digest == other.digest
- if self.tag and other.tag:
- return self.tag == other.tag
- return False
-
- def __ne__(self, other: object) -> bool:
- """Check if two DockerImage instances are not equal."""
- return not self.__eq__(other)
-
- def __hash__(self) -> int:
- """Generate a hash for the DockerImage instance."""
- if self.tag:
- return hash((self.repository, self.tag))
- if self.digest:
- return hash((self.repository, self.digest))
- return hash((self.repository,))
-
-
-class DockerPSResult(BaseModel):
- name: str = Field(..., alias="Name")
- exit_code: int = Field(..., alias="ExitCode")
- state: DockerContainerState = Field(..., alias="State")
- status: str = Field(..., alias="Status")
-
- @property
- def is_running(self) -> bool:
- return self.state == "running"
-
-
-def _build_docker_compose_base_command(
- selected_version: LocalVersionSchema,
-) -> list[str]:
- cmds = []
- compose_files = [selected_version.compose_file]
- env_files = [
- selected_version.api_env_file,
- selected_version.ui_env_file,
- ]
-
- if (
- selected_version.configure_overrides_compose_file.exists()
- and selected_version.configure_overrides_env_file.exists()
- ):
- compose_files.append(selected_version.configure_overrides_compose_file)
- env_files.append(selected_version.configure_overrides_env_file)
-
- for compose_file in compose_files:
- cmds.extend(["-f", compose_file.as_posix()])
-
- for profile in _get_profiles_to_enable(selected_version):
- cmds.extend(["--profile", profile])
-
- if selected_version.arg_overrides_env_file.exists():
- env_files.append(selected_version.arg_overrides_env_file)
-
- for env_file in env_files:
- cmds.extend(["--env-file", env_file.as_posix()])
- return cmds
-
-
-def _check_docker_installed() -> bool:
- """Check if Docker is installed on the system."""
- try:
- cmd = ["docker", "--version"]
- subprocess.run( # noqa: S603
- cmd,
- check=True,
- stdout=subprocess.DEVNULL,
- stderr=subprocess.DEVNULL,
- )
-
- except subprocess.CalledProcessError:
- print_error("Docker is not installed. Please install Docker and try again.")
- return False
-
- return True
-
-
-def _check_docker_compose_installed() -> bool:
- """Check if Docker Compose is installed on the system."""
- try:
- cmd = ["docker", "compose", "--version"]
- subprocess.run( # noqa: S603
- cmd,
- check=True,
- stdout=subprocess.DEVNULL,
- stderr=subprocess.DEVNULL,
- )
- except subprocess.CalledProcessError:
- print_error("Docker Compose is not installed. Please install Docker Compose and try again.")
- return False
- return True
-
-
-def get_required_service_names(selected_version: LocalVersionSchema) -> list[str]:
- """Get the list of require service names from the docker-compose file."""
- contents: dict[str, t.Any] = yaml.safe_load(selected_version.compose_file.read_text())
- services = contents.get("services", {}) or {}
- return [name for name, cfg in services.items() if isinstance(cfg, dict) and "x-required" in cfg]
-
-
-def _get_profiles_to_enable(selected_version: LocalVersionSchema) -> list[str]:
- """Get the list of profiles to enable based on environment variables.
-
- If any of the `x-profile-disabled-vars` are set in the environment,
- the profile will be disabled.
-
- E.g.
-
- services:
- myservice:
- image: myimage:latest
- profiles:
- - myprofile
- x-profile-override-vars:
- - MY_SERVICE_HOST
-
- If MY_SERVICE_HOST is set in the environment, the `myprofile` profile
- will NOT be excluded from the docker compose --profile cmd.
-
- Args:
- selected_version: The selected version of the platform.
-
- Returns:
- List of profile names to enable.
- """
-
- contents: dict[str, t.Any] = yaml.safe_load(selected_version.compose_file.read_text())
- services = contents.get("services", {}) or {}
- profiles_to_enable: set[str] = set()
- for service in services.values():
- if not isinstance(service, dict):
- continue
- profiles = service.get("profiles", [])
- if not profiles or not isinstance(profiles, list):
- continue
- x_override_vars = service.get("x-profile-override-vars", [])
- if not x_override_vars or not isinstance(x_override_vars, list):
- profiles_to_enable.update(profiles)
- continue
-
- configuration_file = selected_version.configure_overrides_env_file
- overrides_file = selected_version.arg_overrides_env_file
- env_vars = {}
- if configuration_file.exists():
- env_vars.update(read_env_file(configuration_file))
- if overrides_file.exists():
- env_vars.update(read_env_file(overrides_file))
- # check if any of the override vars are set in the env
- if any(var in env_vars for var in x_override_vars):
- continue # skip enabling this profile
- profiles_to_enable.update(profiles)
-
- return list(profiles_to_enable)
-
-
-def _run_docker_compose_command(
- args: list[str],
- timeout: int = 300,
- stdin_input: str | None = None,
- capture_output: CaptureOutput | None = None,
-) -> subprocess.CompletedProcess[str]:
- """Execute a docker compose command with common error handling and configuration.
-
- Args:
- args: Additional arguments for the docker compose command.
- compose_file: Path to docker-compose file.
- timeout: Command timeout in seconds.
- command_name: Name of the command for error messages.
- stdin_input: Input to pass to stdin (for commands like docker login).
-
- Returns:
- CompletedProcess object with command results.
-
- Raises:
- subprocess.CalledProcessError: If command fails.
- subprocess.TimeoutExpired: If command times out.
- FileNotFoundError: If docker/docker-compose not found.
- """
- cmd = ["docker", "compose"]
-
- cmd.extend(["-p", DEFAULT_DOCKER_PROJECT_NAME])
-
- # Add the specific command arguments
- cmd.extend(args)
-
- cmd_str = " ".join(cmd)
-
- try:
- # Remove capture_output=True to allow real-time streaming
- # stdout and stderr will go directly to the terminal
- result = subprocess.run( # noqa: S603
- cmd,
- check=True,
- text=True,
- timeout=timeout,
- encoding="utf-8",
- errors="replace",
- input=stdin_input,
- capture_output=bool(capture_output == CaptureOutput.TRUE),
- )
-
- except subprocess.CalledProcessError as e:
- print_error(f"{cmd_str} failed with exit code {e.returncode}")
- raise DockerError(f"Docker command failed: {e}") from e
-
- except subprocess.TimeoutExpired as e:
- print_error(f"{cmd_str} timed out after {timeout} seconds")
- raise DockerError(f"Docker command timed out after {timeout} seconds") from e
-
- except FileNotFoundError as e:
- print_error("Docker or docker compose not found. Please ensure Docker is installed.")
- raise DockerError(f"Docker compose file not found: {e}") from e
-
- return result
-
-
-def build_docker_compose_override_file(
- services: list[PlatformService],
- selected_version: LocalVersionSchema,
-) -> None:
- # build a yaml docker compose override file
- # that only includes the service being configured
- # and has an `env_file` attribute for the service
- override = {
- "services": {
- f"{service}": {"env_file": [selected_version.configure_overrides_env_file.as_posix()]}
- for service in services
- },
- }
-
- with selected_version.configure_overrides_compose_file.open("w") as f:
- safe_dump(override, f, sort_keys=False)
-
-
-def get_available_local_images() -> list[DockerImage]:
- """Get the list of available Docker images on the local system.
-
- Returns:
- list[str]: List of available Docker image names.
- """
- cmd = ["docker", "images", "--format", "{{.Repository}}:{{.Tag}}@{{.Digest}}"]
- cp = subprocess.run( # noqa: S603
- cmd,
- check=True,
- text=True,
- capture_output=True,
- )
- images: list[DockerImage] = []
- for line in cp.stdout.splitlines():
- if line.strip():
- img = DockerImage.from_string(line.strip())
- images.append(img)
- return images
-
-
-def get_env_var_from_container(container_name: str, var_name: str) -> str | None:
- """
- Get the specified environment variable from the container and return
- its value.
-
- Args:
- container_name: Name of the container to inspect.
- var_name: Name of the environment variable to retrieve.
-
- Returns:
- str | None: Value of the environment variable, or None if not found.
- """
- try:
- cmd = [
- "docker",
- "inspect",
- "-f",
- "{{range .Config.Env}}{{println .}}{{end}}",
- container_name,
- ]
- cp = subprocess.run( # noqa: S603
- cmd,
- check=True,
- text=True,
- capture_output=True,
- )
-
- for line in cp.stdout.splitlines():
- if line.startswith(f"{var_name.upper()}="):
- return line.split("=", 1)[1]
-
- except subprocess.CalledProcessError:
- return None
-
- return None
-
-
-def get_required_images(selected_version: LocalVersionSchema) -> list[DockerImage]:
- """Get the list of required Docker images for the specified platform version.
-
- Args:
- selected_version: The selected version of the platform.
-
- Returns:
- list[str]: List of required Docker image names.
- """
- base = _build_docker_compose_base_command(selected_version)
- args = [*base, "config", "--images"]
- result = _run_docker_compose_command(
- args,
- timeout=120,
- capture_output=CaptureOutput.TRUE,
- )
-
- if result.returncode != 0:
- return []
-
- required_images: list[DockerImage] = []
- for line in result.stdout.splitlines():
- if not line.strip():
- continue
- # Validate each line is a valid Docker image string
- DockerImage.from_string(line.strip())
- required_images.append(DockerImage.from_string(line.strip()))
-
- return required_images
-
-
-def docker_requirements_met() -> bool:
- """Check if Docker and Docker Compose are installed."""
- return _check_docker_installed() and _check_docker_compose_installed()
-
-
-def docker_login(registry: str) -> None:
- """Log into a Docker registry using API credentials.
-
- Args:
- registry: Registry hostname to log into.
-
- Raises:
- subprocess.CalledProcessError: If docker login command fails.
- """
-
- print_info(f"Logging in to Docker registry: {registry} ...")
- client = create_api_client()
- container_registry_creds = client.get_container_registry_credentials()
-
- cmd = ["docker", "login", container_registry_creds.registry]
- cmd.extend(["--username", container_registry_creds.username])
- cmd.extend(["--password-stdin"])
-
- try:
- subprocess.run( # noqa: S603
- cmd,
- input=container_registry_creds.password,
- text=True,
- check=True,
- stdout=subprocess.DEVNULL,
- stderr=subprocess.DEVNULL,
- )
- print_success("Logged in to container registry ...")
- except subprocess.CalledProcessError as e:
- print_error(f"Failed to log in to container registry: {e}")
- raise
-
-
-def docker_ps(
- selected_version: LocalVersionSchema,
- timeout: int = 120,
-) -> list[DockerPSResult]:
- """Get container status for the compose project as JSON.
-
- This mirrors:
- docker compose -f <...> -f <...> --env-file <...> --env-file <...> ps --format json [SERVICE...]
-
- Args:
- selected_version: Version object providing compose/env files.
- services: Optional list of PlatformService to filter (translated to 'platform-').
- timeout: Command timeout in seconds.
-
- Returns:
- A list of dicts parsed from `docker compose ps --format json`.
-
- Raises:
- ValueError: If the returned output is not valid JSON.
- subprocess.CalledProcessError / TimeoutExpired / FileNotFoundError: On execution errors.
- """
- base = _build_docker_compose_base_command(selected_version)
- args = [*base, "ps", "--format", "json"]
-
- result = _run_docker_compose_command(
- args,
- timeout=timeout,
- capture_output=CaptureOutput.TRUE,
- )
-
- try:
- # docker compose ps --format json returns a JSON array
- if not result.stdout:
- return []
- stdout = str(result.stdout)
- stdout_lines = stdout.splitlines()
- container_info_models: list[DockerPSResult] = []
- for line in stdout_lines:
- if not line.strip():
- continue
- j = json.loads(line)
- dpr = DockerPSResult(**j)
- container_info_models.append(dpr)
- except json.JSONDecodeError as e:
- print_error(f"Failed read status from the Dreadnode Platform': {e}")
- raise ValueError("Unexpected non-JSON output from 'docker compose ps'") from e
-
- return container_info_models
-
-
-def docker_run(
- selected_version: LocalVersionSchema,
- timeout: int = 300,
-) -> subprocess.CompletedProcess[str]:
- """Run docker containers for the platform.
-
- Args:
- compose_file: Path to docker-compose file.
- timeout: Command timeout in seconds.
-
- Returns:
- CompletedProcess object with command results.
-
- Raises:
- subprocess.CalledProcessError: If command fails.
- subprocess.TimeoutExpired: If command times out.
- """
- cmds = _build_docker_compose_base_command(selected_version)
-
- # Apply the compose and env override files in priority order
- # 1. base compose file and env files
- # 2. configure overrides compose and env files (if any)
- # 3. arg overrides env file (if any)
-
- cmds += ["up", "-d"]
- return _run_docker_compose_command(cmds, timeout=timeout)
-
-
-def docker_stop(
- selected_version: LocalVersionSchema,
- timeout: int = 300,
-) -> subprocess.CompletedProcess[str]:
- """Stop docker containers for the platform.
-
- Args:
- selected_version: The selected version of the platform.
- timeout: Command timeout in seconds.
-
- Returns:
- CompletedProcess object with command results.
-
- Raises:
- subprocess.CalledProcessError: If command fails.
- subprocess.TimeoutExpired: If command times out.
- """
- cmds = _build_docker_compose_base_command(selected_version)
- cmds.append("down")
- return _run_docker_compose_command(cmds, timeout=timeout)
diff --git a/dreadnode/cli/platform/download.py b/dreadnode/cli/platform/download.py
index 18f50943..1bb5ea14 100644
--- a/dreadnode/cli/platform/download.py
+++ b/dreadnode/cli/platform/download.py
@@ -1,144 +1,95 @@
import io
-import json
import zipfile
-from dreadnode.api.models import RegistryImageDetails
from dreadnode.cli.api import create_api_client
-from dreadnode.cli.platform.constants import SERVICES, VERSIONS_MANIFEST
-from dreadnode.cli.platform.schemas import LocalVersionSchema
-from dreadnode.cli.platform.utils.env_mgmt import (
+from dreadnode.cli.docker import docker_run
+from dreadnode.cli.platform.compose import compose_login, get_compose_args
+from dreadnode.cli.platform.constants import PLATFORM_SERVICES, PLATFORM_STORAGE_DIR
+from dreadnode.cli.platform.env_mgmt import (
create_default_env_files,
)
-from dreadnode.cli.platform.utils.printing import (
- print_error,
+from dreadnode.cli.platform.tag import add_tag_arch_suffix
+from dreadnode.cli.platform.version import (
+ LocalVersion,
+ VersionConfig,
+)
+from dreadnode.logging_ import (
+ confirm,
print_info,
print_success,
- print_warning,
-)
-from dreadnode.cli.platform.utils.versions import (
- confirm_with_context,
- create_local_latest_tag,
- get_available_local_versions,
- get_cli_version,
- get_local_cache_dir,
)
-def _resolve_latest(tag: str) -> str:
- """Resolve 'latest' tag to actual version tag from API.
-
- Args:
- tag: Version tag that contains 'latest'.
-
- Returns:
- str: Resolved actual version tag.
+def download_platform(tag: str | None = None) -> LocalVersion:
"""
- api_client = create_api_client()
- release_info = api_client.get_platform_releases(
- tag, services=[str(service) for service in SERVICES], cli_version=get_cli_version()
- )
- return release_info.tag
-
-
-def _create_local_version_file_structure(
- tag: str, release_info: RegistryImageDetails
-) -> LocalVersionSchema:
- """Create local file structure and update manifest for a new version.
+ Download platform version if not already available locally.
Args:
- tag: Version tag to create structure for.
- release_info: Registry image details from API.
+ tag: Version tag to download (supports 'latest').
Returns:
- LocalVersionSchema: Created local version schema.
+        LocalVersion: Local version schema for the downloaded/existing version.
"""
- available_local_versions = get_available_local_versions()
-
- # Create a new local version schema
- local_cache_dir = get_local_cache_dir()
- new_version = LocalVersionSchema(
- **release_info.model_dump(),
- local_path=local_cache_dir / tag,
- current=False,
- )
-
- # Add the new version to the available local versions
- available_local_versions.versions.append(new_version)
-
- # sort the manifest by semver, newest first
- available_local_versions.versions.sort(key=lambda v: v.tag, reverse=True)
-
- # update the manifest file
- manifest_path = local_cache_dir / VERSIONS_MANIFEST
- with manifest_path.open(encoding="utf-8", mode="w") as f:
- json.dump(available_local_versions.model_dump(), f, indent=2)
+ version_config = VersionConfig.read()
+ api_client = create_api_client()
- print_success(f"Updated versions manifest at {manifest_path} with {new_version.tag}")
+ # 1 - Resolve the tag
- if new_version.local_path.exists():
- print_warning(f"Version {tag} already exists locally.")
- if not confirm_with_context("overwrite it?"):
- print_error("Aborting download.")
- return new_version
+ tag = tag or "latest"
+ tag = add_tag_arch_suffix(tag)
- # create the directory
- new_version.local_path.mkdir(parents=True, exist_ok=True)
+ if "latest" in tag:
+ tag = api_client.get_platform_releases(
+ tag, services=[str(service) for service in PLATFORM_SERVICES]
+ ).tag
- return new_version
+ # 2 - Check if the version is already available locally
+ if version_config.versions:
+ for available_local_version in version_config.versions:
+ if tag == available_local_version.tag:
+ print_success(f"[cyan]{tag}[/] is already downloaded.")
+ return available_local_version
-def _download_version_files(tag: str) -> LocalVersionSchema:
- """Download platform version files from API and extract locally.
+ # 3 - Download and check release info
- Args:
- tag: Version tag to download.
+ print_info(f"Downloading [cyan]{tag}[/] ...")
- Returns:
- LocalVersionSchema: Downloaded local version schema.
- """
api_client = create_api_client()
release_info = api_client.get_platform_releases(
- tag, services=[str(service) for service in SERVICES], cli_version=get_cli_version()
+ tag, services=[str(service) for service in PLATFORM_SERVICES]
)
- zip_content = api_client.get_platform_templates(tag)
- new_local_version = _create_local_version_file_structure(release_info.tag, release_info)
-
- with zipfile.ZipFile(io.BytesIO(zip_content)) as zip_file:
- zip_file.extractall(new_local_version.local_path)
- print_success(f"Downloaded version {tag} to {new_local_version.local_path}")
+ new_version = LocalVersion(
+ **release_info.model_dump(),
+ local_path=PLATFORM_STORAGE_DIR / tag,
+ current=False,
+ )
- create_default_env_files(new_local_version)
- return new_local_version
+ if new_version.local_path.exists() and not confirm(
+ f"{new_version.local_path} exists, overwrite?"
+ ):
+ return new_version
+ version_config.add_version(new_version)
-def download_platform(tag: str | None = None) -> LocalVersionSchema:
- """Download platform version if not already available locally.
+ # 4 - Pull the release zip and extract it
- Args:
- tag: Version tag to download (supports 'latest').
+ zip_content = api_client.get_platform_templates(tag)
+ new_version.local_path.mkdir(parents=True, exist_ok=True)
- Returns:
- LocalVersionSchema: Local version schema for the downloaded/existing version.
- """
- if not tag or tag == "latest":
- # all remote images are tagged with architecture
- tag = create_local_latest_tag()
+ with zipfile.ZipFile(io.BytesIO(zip_content)) as zip_file:
+ zip_file.extractall(new_version.local_path)
+ print_success(f"Downloaded [cyan]{tag}[/] to {new_version.local_path}")
- if "latest" in tag:
- tag = _resolve_latest(tag)
+ create_default_env_files(new_version)
- # get what's available
- available_local_versions = get_available_local_versions()
+ # 5 - Pull the images
- # if there are versions available
- if available_local_versions.versions:
- for available_local_version in available_local_versions.versions:
- if tag == available_local_version.tag:
- print_success(
- f"Version {tag} is already downloaded at {available_local_version.local_path}"
- )
- return available_local_version
+ print_info(f"Pulling Docker images for [cyan]{tag}[/] ...")
+ compose_login(new_version, force=True)
+ docker_run(
+ ["compose", *get_compose_args(new_version), "pull"],
+ )
- print_info(f"Version {tag} is not available locally. Attempting to download it now ...")
- return _download_version_files(tag)
+ return new_version
diff --git a/dreadnode/cli/platform/utils/env_mgmt.py b/dreadnode/cli/platform/env_mgmt.py
similarity index 96%
rename from dreadnode/cli/platform/utils/env_mgmt.py
rename to dreadnode/cli/platform/env_mgmt.py
index c0303a06..fc935445 100644
--- a/dreadnode/cli/platform/utils/env_mgmt.py
+++ b/dreadnode/cli/platform/env_mgmt.py
@@ -1,13 +1,13 @@
-import subprocess
+import subprocess # nosec
import sys
import typing as t
from pathlib import Path
from dreadnode.cli.platform.constants import (
- SERVICES,
+ PLATFORM_SERVICES,
)
-from dreadnode.cli.platform.schemas import LocalVersionSchema
-from dreadnode.cli.platform.utils.printing import print_error, print_info
+from dreadnode.cli.platform.version import LocalVersion
+from dreadnode.logging_ import print_error, print_info
LineTypes = t.Literal["variable", "comment", "empty"]
@@ -54,7 +54,8 @@ def _parse_env_lines(content: str) -> list[_EnvLine]:
def _extract_variables(lines: list[_EnvLine]) -> dict[str, str]:
- """Extract just the variables from parsed lines.
+ """
+ Extract just the variables from parsed lines.
Args:
lines: List of parsed environment file lines.
@@ -72,7 +73,8 @@ def _extract_variables(lines: list[_EnvLine]) -> dict[str, str]:
def _find_insertion_points(
base_lines: list[_EnvLine], remote_lines: list[_EnvLine], new_vars: dict[str, str]
) -> dict[str, int]:
- """Find the best insertion points for new variables based on remote file structure.
+ """
+ Find the best insertion points for new variables based on remote file structure.
Args:
base_lines: Lines from local file.
@@ -148,7 +150,8 @@ def _find_insertion_points(
def _reconstruct_env_content( # noqa: PLR0912
base_lines: list[_EnvLine], merged_vars: dict[str, str], updated_remote_lines: list[_EnvLine]
) -> str:
- """Reconstruct .env content preserving structure from base while applying merged variables.
+ """
+ Reconstruct .env content preserving structure from base while applying merged variables.
Args:
base_lines: Parsed lines from the local file (for structure).
@@ -246,7 +249,7 @@ def _reconstruct_env_content( # noqa: PLR0912
return "\n".join(result_lines)
-def create_default_env_files(current_version: LocalVersionSchema) -> None:
+def create_default_env_files(current_version: LocalVersion) -> None:
"""Create default environment files for all services in the current version.
Copies sample environment files to actual environment files if they don't exist,
@@ -258,7 +261,7 @@ def create_default_env_files(current_version: LocalVersionSchema) -> None:
Raises:
RuntimeError: If sample environment files are not found or .env file creation fails.
"""
- for service in SERVICES:
+ for service in PLATFORM_SERVICES:
for image in current_version.images:
if image.service == service:
env_file_path = current_version.get_env_path_by_service(service)
@@ -287,7 +290,7 @@ def open_env_file(filename: Path) -> None:
else:
cmd = ["xdg-open", filename.as_posix()]
try:
- subprocess.run(cmd, check=False) # noqa: S603
+ subprocess.run(cmd, check=False) # noqa: S603 # nosec
print_info("Opened environment file.")
except subprocess.CalledProcessError as e:
print_error(f"Failed to open environment file: {e}")
diff --git a/dreadnode/cli/platform/login.py b/dreadnode/cli/platform/login.py
deleted file mode 100644
index ab020199..00000000
--- a/dreadnode/cli/platform/login.py
+++ /dev/null
@@ -1,20 +0,0 @@
-from dreadnode.cli.platform.docker_ import docker_login
-from dreadnode.cli.platform.utils.printing import print_info
-from dreadnode.cli.platform.utils.versions import get_current_version
-
-
-def log_into_registries() -> None:
- """Log into all Docker registries for the current platform version.
-
- Iterates through all images in the current version and logs into their
- respective registries. If no current version is set, displays an error message.
- """
- current_version = get_current_version()
- if not current_version:
- print_info("There are no registries configured. Run `dreadnode platform start` to start.")
- return
- registries_attempted = set()
- for image in current_version.images:
- if image.registry not in registries_attempted:
- docker_login(image.registry)
- registries_attempted.add(image.registry)
diff --git a/dreadnode/cli/platform/schemas.py b/dreadnode/cli/platform/schemas.py
deleted file mode 100644
index 1bca04d1..00000000
--- a/dreadnode/cli/platform/schemas.py
+++ /dev/null
@@ -1,117 +0,0 @@
-from pathlib import Path
-
-from pydantic import BaseModel, field_serializer
-
-from dreadnode.api.models import RegistryImageDetails
-from dreadnode.cli.platform.constants import API_SERVICE, UI_SERVICE
-
-
-class LocalVersionSchema(RegistryImageDetails):
- local_path: Path
- current: bool
-
- def __str__(self) -> str:
- return self.tag
-
- @field_serializer("local_path")
- def serialize_path(self, path: Path) -> str:
- """Serialize Path object to absolute path string.
-
- Args:
- path: Path object to serialize.
-
- Returns:
- str: Absolute path as string.
- """
- return str(path.resolve()) # Convert to absolute path string
-
- @property
- def details(self) -> str:
- configured_overrides = (
- "\n".join(
- f" - {line}" for line in self.configure_overrides_env_file.read_text().splitlines()
- )
- if self.configure_overrides_env_file.exists()
- else " (none)"
- )
-
- return (
- f"Tag: {self.tag}\n"
- f"Local Path: {self.local_path}\n"
- f"Compose File: {self.compose_file}\n"
- f"API Env File: {self.api_env_file}\n"
- f"UI Env File: {self.ui_env_file}\n"
- f"Configured: \n{configured_overrides}\n"
- )
-
- @property
- def compose_file(self) -> Path:
- return self.local_path / "docker-compose.yaml"
-
- @property
- def api_env_file(self) -> Path:
- return self.local_path / f".{API_SERVICE}.env"
-
- @property
- def api_example_env_file(self) -> Path:
- return self.local_path / f".{API_SERVICE}.example.env"
-
- @property
- def ui_env_file(self) -> Path:
- return self.local_path / f".{UI_SERVICE}.env"
-
- @property
- def ui_example_env_file(self) -> Path:
- return self.local_path / f".{UI_SERVICE}.example.env"
-
- @property
- def configure_overrides_env_file(self) -> Path:
- return self.local_path / ".configure.overrides.env"
-
- @property
- def configure_overrides_compose_file(self) -> Path:
- return self.local_path / "docker-compose.configure.overrides.yaml"
-
- @property
- def arg_overrides_env_file(self) -> Path:
- return self.local_path / ".arg.overrides.env"
-
- def get_env_path_by_service(self, service: str) -> Path:
- """Get environment file path for a specific service.
-
- Args:
- service: Service name to get env path for.
-
- Returns:
- Path: Path to the service's environment file.
-
- Raises:
- ValueError: If service is not recognized.
- """
- if service == API_SERVICE:
- return self.api_env_file
- if service == UI_SERVICE:
- return self.ui_env_file
- raise ValueError(f"Unknown service: {service}")
-
- def get_example_env_path_by_service(self, service: str) -> Path:
- """Get example environment file path for a specific service.
-
- Args:
- service: Service name to get example env path for.
-
- Returns:
- Path: Path to the service's example environment file.
-
- Raises:
- ValueError: If service is not recognized.
- """
- if service == API_SERVICE:
- return self.api_example_env_file
- if service == UI_SERVICE:
- return self.ui_example_env_file
- raise ValueError(f"Unknown service: {service}")
-
-
-class LocalVersionsSchema(BaseModel):
- versions: list[LocalVersionSchema]
diff --git a/dreadnode/cli/platform/start.py b/dreadnode/cli/platform/start.py
deleted file mode 100644
index f643c2d8..00000000
--- a/dreadnode/cli/platform/start.py
+++ /dev/null
@@ -1,79 +0,0 @@
-from dreadnode.cli.platform.docker_ import (
- DockerError,
- docker_login,
- docker_requirements_met,
- docker_run,
- docker_stop,
- get_available_local_images,
- get_env_var_from_container,
- get_required_images,
-)
-from dreadnode.cli.platform.download import download_platform
-from dreadnode.cli.platform.status import platform_is_running
-from dreadnode.cli.platform.utils.env_mgmt import write_overrides_env
-from dreadnode.cli.platform.utils.printing import print_error, print_info, print_success
-from dreadnode.cli.platform.utils.versions import (
- create_local_latest_tag,
- get_current_version,
- mark_current_version,
-)
-
-
-def start_platform(tag: str | None = None, **env_overrides: str) -> None:
- """Start the platform with the specified or current version.
-
- Args:
- tag: Optional image tag to use. If not provided, uses the current
- version or downloads the latest available version.
- """
- if not docker_requirements_met():
- print_error("Docker and Docker Compose must be installed to start the platform.")
- return
-
- if tag:
- selected_version = download_platform(tag)
- mark_current_version(selected_version)
- elif current_version := get_current_version():
- selected_version = current_version
- # no need to mark
- else:
- latest_tag = create_local_latest_tag()
- selected_version = download_platform(latest_tag)
- mark_current_version(selected_version)
-
- is_running = platform_is_running(selected_version)
- if is_running:
- print_info(f"Platform {selected_version.tag} is already running.")
- print_info("Use `dreadnode platform stop` to stop it first.")
- return
-
- # check to see if all required images are available locally
- required_images = get_required_images(selected_version)
- available_images = get_available_local_images()
- missing_images = [img for img in required_images if img not in available_images]
- if missing_images:
- registries_attempted = set()
- for image in selected_version.images:
- if image.registry not in registries_attempted:
- docker_login(image.registry)
- registries_attempted.add(image.registry)
-
- if env_overrides:
- write_overrides_env(selected_version.arg_overrides_env_file, **env_overrides)
-
- print_info(f"Starting platform: {selected_version.tag}")
- try:
- docker_run(selected_version)
- print_success(f"Platform {selected_version.tag} started successfully.")
- origin = get_env_var_from_container("dreadnode-ui", "ORIGIN")
- if origin:
- print_info("You can access the app at the following URLs:")
- print_info(f" - {origin}")
- else:
- print_info(" - Unable to determine the app URL.")
- print_info("Please check the container logs for more information.")
- except DockerError as e:
- print_error(f"Failed to start platform {selected_version.tag}: {e}")
- print_info("Stopping any partially started containers...")
- docker_stop(selected_version)
- print_info("You can check the logs for more details.")
diff --git a/dreadnode/cli/platform/status.py b/dreadnode/cli/platform/status.py
deleted file mode 100644
index 02b29dfd..00000000
--- a/dreadnode/cli/platform/status.py
+++ /dev/null
@@ -1,43 +0,0 @@
-from dreadnode.cli.platform.docker_ import docker_ps, get_required_service_names
-from dreadnode.cli.platform.schemas import LocalVersionSchema
-from dreadnode.cli.platform.utils.printing import print_error, print_success
-from dreadnode.cli.platform.utils.versions import get_current_version, get_local_version
-
-
-def platform_is_running(selected_version: LocalVersionSchema) -> bool:
- """Check if the platform with the specified or current version is running.
-
- Args:
- tag: Optional image tag to use. If not provided, uses the current
- version or downloads the latest available version.
- """
- required_services = get_required_service_names(selected_version)
- container_details = docker_ps(selected_version)
- if not container_details:
- return False
- for service in required_services:
- if service not in [c.name for c in container_details if c.status == "running"]:
- return False
- return True
-
-
-def platform_status(tag: str | None = None) -> bool:
- """Get the status of the platform with the specified or current version.
-
- Args:
- tag: Optional image tag to use. If not provided, uses the current
- version or downloads the latest available version.
- """
- if tag:
- selected_version = get_local_version(tag)
- elif current_version := get_current_version():
- selected_version = current_version
- else:
- print_error("No current platform version is set. Please start or download the platform.")
- return False
- required_containers_running = platform_is_running(selected_version)
- if required_containers_running:
- print_success(f"Platform {selected_version.tag} is running.")
- else:
- print_error(f"Platform {selected_version.tag} is not fully running.")
- return required_containers_running
diff --git a/dreadnode/cli/platform/stop.py b/dreadnode/cli/platform/stop.py
deleted file mode 100644
index 1079cfb2..00000000
--- a/dreadnode/cli/platform/stop.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from dreadnode.cli.platform.docker_ import docker_stop
-from dreadnode.cli.platform.utils.env_mgmt import remove_overrides_env
-from dreadnode.cli.platform.utils.printing import print_error, print_success
-from dreadnode.cli.platform.utils.versions import (
- get_current_version,
-)
-
-
-def stop_platform() -> None:
- """Stop the currently running platform.
-
- Uses the current version's compose file to stop all platform containers
- via docker compose down.
- """
- current_version = get_current_version()
- if not current_version:
- print_error("No current version found. Nothing to stop.")
- return
- remove_overrides_env(current_version.arg_overrides_env_file)
- docker_stop(current_version)
- print_success("Platform stopped successfully.")
diff --git a/dreadnode/cli/platform/tag.py b/dreadnode/cli/platform/tag.py
new file mode 100644
index 00000000..e93ed723
--- /dev/null
+++ b/dreadnode/cli/platform/tag.py
@@ -0,0 +1,40 @@
+import platform
+
+from packaging.version import Version
+
+from dreadnode.cli.platform.constants import SUPPORTED_ARCHITECTURES
+
+
+def add_tag_arch_suffix(tag: str) -> str:
+ """
+ Add architecture suffix to a tag if it doesn't already have one.
+
+ Args:
+ tag: The original tag string.
+ """
+ if any(tag.endswith(f"-{arch}") for arch in SUPPORTED_ARCHITECTURES):
+ return tag # Tag already has a supported architecture suffix
+
+ arch = platform.machine()
+
+ if arch in ["x86_64", "AMD64"]:
+ arch = "amd64"
+ elif arch in ["arm64", "aarch64", "ARM64"]:
+ arch = "arm64"
+ else:
+ raise ValueError(f"Unsupported architecture: {arch}")
+
+ return f"{tag}-{arch}"
+
+
+def tag_to_semver(tag: str) -> Version:
+ """
+ Extract semantic version from a tag by removing architecture suffix.
+
+ Args:
+ tag: The tag string that may contain an architecture suffix.
+
+ Returns:
+ A packaging Version object representing the semantic version.
+ """
+ return Version(tag.split("-")[0].removeprefix("v"))
diff --git a/dreadnode/cli/platform/upgrade.py b/dreadnode/cli/platform/upgrade.py
deleted file mode 100644
index e267bdee..00000000
--- a/dreadnode/cli/platform/upgrade.py
+++ /dev/null
@@ -1,64 +0,0 @@
-from dreadnode.cli.platform.docker_ import docker_stop
-from dreadnode.cli.platform.download import download_platform
-from dreadnode.cli.platform.start import start_platform
-from dreadnode.cli.platform.utils.printing import print_error, print_info
-from dreadnode.cli.platform.utils.versions import (
- confirm_with_context,
- create_local_latest_tag,
- get_current_version,
- get_semver_from_tag,
- mark_current_version,
- newer_remote_version,
-)
-
-
-def upgrade_platform() -> None:
- """Upgrade the platform to the latest available version.
-
- Downloads the latest version, compares it with the current version,
- and performs the upgrade if a newer version is available. Optionally
- merges configuration files from the current version to the new version.
- Stops the current platform and starts the upgraded version.
- """
- latest_tag = create_local_latest_tag()
-
- latest_version = download_platform(latest_tag)
- current_local_version = get_current_version()
-
- if not current_local_version:
- print_error(
- "No current platform version found. Run `dreadnode platform start` to start the latest version."
- )
- return
-
- current_semver = get_semver_from_tag(current_local_version.tag)
- remote_semver = get_semver_from_tag(latest_version.tag)
-
- if not newer_remote_version(current_semver, remote_semver):
- print_info(f"You are using the latest ({current_semver}) version of the platform.")
- return
-
- if not confirm_with_context(
- f"Are you sure you want to upgrade from {current_local_version.tag} to {latest_version.tag}?"
- ):
- print_error("Aborting upgrade.")
- return
-
- # copy the configuration overrides from the current version to the new version
- if (
- current_local_version.configure_overrides_compose_file.exists()
- and current_local_version.configure_overrides_env_file.exists()
- ):
- latest_version.configure_overrides_compose_file.write_text(
- current_local_version.configure_overrides_compose_file.read_text()
- )
- latest_version.configure_overrides_env_file.write_text(
- current_local_version.configure_overrides_env_file.read_text()
- )
-
- print_info(f"Stopping current platform version {current_local_version.tag}...")
- docker_stop(current_local_version)
- print_info(f"Current platform version {current_local_version.tag} stopped.")
-
- mark_current_version(latest_version)
- start_platform()
diff --git a/dreadnode/cli/platform/utils/__init__.py b/dreadnode/cli/platform/utils/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/dreadnode/cli/platform/utils/printing.py b/dreadnode/cli/platform/utils/printing.py
deleted file mode 100644
index 2691e44e..00000000
--- a/dreadnode/cli/platform/utils/printing.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import sys
-
-import rich
-
-
-def print_success(message: str, prefix: str | None = None) -> None:
- """Print success message in green"""
- prefix = prefix or "✓"
- rich.print(f"[bold green]{prefix}[/] [green]{message}[/]")
-
-
-def print_error(message: str, prefix: str | None = None) -> None:
- """Print error message in red"""
- prefix = prefix or "✗"
- rich.print(f"[bold red]{prefix}[/] [red]{message}[/]", file=sys.stderr)
-
-
-def print_warning(message: str, prefix: str | None = None) -> None:
- """Print warning message in yellow"""
- prefix = prefix or "⚠"
- rich.print(f"[bold yellow]{prefix}[/] [yellow]{message}[/]")
-
-
-def print_info(message: str, prefix: str | None = None) -> None:
- """Print info message in blue"""
- prefix = prefix or "i"
- rich.print(f"[bold blue]{prefix}[/] [blue]{message}[/]")
-
-
-def print_debug(message: str, prefix: str | None = None) -> None:
- """Print debug message in dim gray"""
- prefix = prefix or "🐛"
- rich.print(f"[dim]{prefix}[/] [dim]{message}[/]")
-
-
-def print_heading(message: str) -> None:
- """Print section heading"""
- rich.print(f"\n[bold underline]{message}[/]\n")
-
-
-def print_muted(message: str) -> None:
- """Print muted text"""
- rich.print(f"[dim]{message}[/]")
diff --git a/dreadnode/cli/platform/utils/versions.py b/dreadnode/cli/platform/utils/versions.py
deleted file mode 100644
index 17226cd9..00000000
--- a/dreadnode/cli/platform/utils/versions.py
+++ /dev/null
@@ -1,184 +0,0 @@
-import importlib.metadata
-import json
-import platform
-from pathlib import Path
-
-from packaging.version import Version
-from rich.prompt import Confirm
-
-from dreadnode.cli.platform.constants import (
- SUPPORTED_ARCHITECTURES,
- VERSIONS_MANIFEST,
- SupportedArchitecture,
-)
-from dreadnode.cli.platform.schemas import LocalVersionSchema, LocalVersionsSchema
-from dreadnode.constants import DEFAULT_LOCAL_STORAGE_DIR
-
-
-def _get_local_arch() -> SupportedArchitecture:
- """Get the local machine architecture in supported format.
-
- Returns:
- SupportedArchitecture: The architecture as either "amd64" or "arm64".
-
- Raises:
- ValueError: If the local architecture is not supported.
- """
- arch = platform.machine()
-
- if arch in ["x86_64", "AMD64"]:
- return "amd64"
- if arch in ["arm64", "aarch64", "ARM64"]:
- return "arm64"
- raise ValueError(f"Unsupported architecture: {arch}")
-
-
-def get_local_cache_dir() -> Path:
- """Get the local cache directory path for dreadnode platform files.
-
- Returns:
- Path: Path to the local cache directory (~//platform).
- """
- return DEFAULT_LOCAL_STORAGE_DIR / "platform"
-
-
-def get_cli_version() -> str:
- """Get the version of the dreadnode CLI package.
-
- Returns:
- str | None: The version string of the dreadnode package, or None if not found.
- """
- return importlib.metadata.version("dreadnode")
-
-
-def confirm_with_context(action: str) -> bool:
- """Prompt the user for confirmation with a formatted action message.
-
- Args:
- action: The action description to display in the confirmation prompt.
-
- Returns:
- bool: True if the user confirms, False otherwise. Defaults to False.
- """
- return Confirm.ask(f"[bold blue]{action}[/bold blue]", default=False)
-
-
-def get_available_local_versions() -> LocalVersionsSchema:
- """Get all available local platform versions from the manifest file.
-
- Creates the manifest file with an empty schema if it doesn't exist.
-
- Returns:
- LocalVersionsSchema: Schema containing all available local platform versions.
- """
- try:
- local_cache_dir = get_local_cache_dir()
- manifest_path = local_cache_dir / VERSIONS_MANIFEST
- with manifest_path.open(encoding="utf-8") as f:
- versions_manifest_data = json.load(f)
- return LocalVersionsSchema(**versions_manifest_data)
- except FileNotFoundError:
- # create the file
- local_cache_dir = get_local_cache_dir()
- manifest_path = local_cache_dir / VERSIONS_MANIFEST
- manifest_path.parent.mkdir(parents=True, exist_ok=True)
- blank_schema = LocalVersionsSchema(versions=[])
- with manifest_path.open(encoding="utf-8", mode="w") as f:
- json.dump(blank_schema.model_dump(), f)
- return blank_schema
-
-
-def get_current_version() -> LocalVersionSchema | None:
- """Get the currently active local platform version.
-
- Returns:
- LocalVersionSchema | None: The current version schema if one is marked as current,
- None otherwise.
- """
- available_local_versions = get_available_local_versions()
- if not available_local_versions.versions:
- return None
- for version in available_local_versions.versions:
- if version.current:
- return version
- return None
-
-
-def get_local_version(tag: str) -> LocalVersionSchema:
- """Get a specific local platform version by its tag.
-
- Args:
- tag: The tag of the version to retrieve.
-
- Returns:
- LocalVersionSchema: The version schema matching the provided tag.
-
- Raises:
- ValueError: If no version with the specified tag is found.
- """
- available_local_versions = get_available_local_versions()
- for version in available_local_versions.versions:
- if version.tag == tag:
- return version
- raise ValueError(f"No local version found with tag: {tag}")
-
-
-def mark_current_version(current_version: LocalVersionSchema) -> None:
- """Mark a specific version as the current active version.
-
- Updates the versions manifest to mark the specified version as current
- and all others as not current.
-
- Args:
- current_version: The version to mark as current.
- """
- available_local_versions = get_available_local_versions()
- for available_version in available_local_versions.versions:
- if available_version.tag == current_version.tag:
- available_version.current = True
- else:
- available_version.current = False
-
- local_cache_dir = get_local_cache_dir()
- manifest_path = local_cache_dir / VERSIONS_MANIFEST
- with manifest_path.open(encoding="utf-8", mode="w") as f:
- json.dump(available_local_versions.model_dump(), f, indent=2)
-
-
-def create_local_latest_tag() -> str:
- """Create a latest tag string for the local architecture.
-
- Returns:
- str: A tag in the format "latest-{arch}" where arch is the local architecture.
- """
- arch = _get_local_arch()
- return f"latest-{arch}"
-
-
-def get_semver_from_tag(tag: str) -> str:
- """Extract semantic version from a tag by removing architecture suffix.
-
- Args:
- tag: The tag string that may contain an architecture suffix.
-
- Returns:
- str: The tag with any supported architecture suffix removed.
- """
- for arch in SUPPORTED_ARCHITECTURES:
- if arch in tag:
- return tag.replace(f"-{arch}", "")
- return tag
-
-
-def newer_remote_version(local_version: str, remote_version: str) -> bool:
- """Check if the remote version is newer than the local version.
-
- Args:
- local_version: The local version string in semantic version format.
- remote_version: The remote version string in semantic version format.
-
- Returns:
- bool: True if the remote version is newer than the local version, False otherwise.
- """
- # compare the semvers of two versions to see if the remote is "newer"
- return Version(remote_version) > Version(local_version)
diff --git a/dreadnode/cli/platform/version.py b/dreadnode/cli/platform/version.py
new file mode 100644
index 00000000..605643bc
--- /dev/null
+++ b/dreadnode/cli/platform/version.py
@@ -0,0 +1,216 @@
+from pathlib import Path
+
+from pydantic import BaseModel, field_serializer
+
+from dreadnode.api.models import RegistryImageDetails
+from dreadnode.cli.platform.constants import (
+ API_SERVICE,
+ UI_SERVICE,
+ VERSION_CONFIG_PATH,
+)
+from dreadnode.cli.platform.tag import tag_to_semver
+
+
+class LocalVersion(RegistryImageDetails):
+ local_path: Path
+ current: bool
+
+ def __str__(self) -> str:
+ return self.tag
+
+ @field_serializer("local_path")
+ def serialize_path(self, path: Path) -> str:
+ """Serialize Path object to absolute path string.
+
+ Args:
+ path: Path object to serialize.
+
+ Returns:
+ str: Absolute path as string.
+ """
+ return str(path.resolve()) # Convert to absolute path string
+
+ @property
+ def details(self) -> str:
+ configured_overrides = (
+ "\n".join(
+ f" - {line}" for line in self.configure_overrides_env_file.read_text().splitlines()
+ )
+ if self.configure_overrides_env_file.exists()
+ else " (none)"
+ )
+
+ return (
+ f"Tag: {self.tag}\n"
+ f"Local Path: {self.local_path}\n"
+ f"Compose File: {self.compose_file}\n"
+ f"API Env File: {self.api_env_file}\n"
+ f"UI Env File: {self.ui_env_file}\n"
+ f"Configured: \n{configured_overrides}\n"
+ )
+
+ @property
+ def compose_file(self) -> Path:
+ return self.local_path / "docker-compose.yaml"
+
+ @property
+ def api_env_file(self) -> Path:
+ return self.local_path / f".{API_SERVICE}.env"
+
+ @property
+ def api_example_env_file(self) -> Path:
+ return self.local_path / f".{API_SERVICE}.example.env"
+
+ @property
+ def ui_env_file(self) -> Path:
+ return self.local_path / f".{UI_SERVICE}.env"
+
+ @property
+ def ui_example_env_file(self) -> Path:
+ return self.local_path / f".{UI_SERVICE}.example.env"
+
+ @property
+ def configure_overrides_env_file(self) -> Path:
+ return self.local_path / ".configure.overrides.env"
+
+ @property
+ def configure_overrides_compose_file(self) -> Path:
+ return self.local_path / "docker-compose.configure.overrides.yaml"
+
+ @property
+ def arg_overrides_env_file(self) -> Path:
+ return self.local_path / ".arg.overrides.env"
+
+ def get_env_path_by_service(self, service: str) -> Path:
+ """Get environment file path for a specific service.
+
+ Args:
+ service: Service name to get env path for.
+
+ Returns:
+ Path: Path to the service's environment file.
+
+ Raises:
+ ValueError: If service is not recognized.
+ """
+ if service == API_SERVICE:
+ return self.api_env_file
+ if service == UI_SERVICE:
+ return self.ui_env_file
+ raise ValueError(f"Unknown service: {service}")
+
+ def get_example_env_path_by_service(self, service: str) -> Path:
+ """Get example environment file path for a specific service.
+
+ Args:
+ service: Service name to get example env path for.
+
+ Returns:
+ Path: Path to the service's example environment file.
+
+ Raises:
+ ValueError: If service is not recognized.
+ """
+ if service == API_SERVICE:
+ return self.api_example_env_file
+ if service == UI_SERVICE:
+ return self.ui_example_env_file
+ raise ValueError(f"Unknown service: {service}")
+
+
+class VersionConfig(BaseModel):
+ versions: list[LocalVersion]
+
+ @classmethod
+ def read(cls) -> "VersionConfig":
+ """Read the version configuration from the file system or return an empty instance."""
+
+ if not VERSION_CONFIG_PATH.exists():
+ return cls(versions=[])
+
+ with VERSION_CONFIG_PATH.open("r") as f:
+ return cls.model_validate_json(f.read())
+
+ def write(self) -> None:
+ """Write the versions configuration to the file system."""
+
+ if not VERSION_CONFIG_PATH.parent.exists():
+ VERSION_CONFIG_PATH.parent.mkdir(parents=True)
+
+ with VERSION_CONFIG_PATH.open("w") as f:
+ f.write(self.model_dump_json())
+
+ def add_version(self, version: LocalVersion) -> None:
+ """
+ Add a new version to the configuration if it doesn't already exist.
+
+ Args:
+ version: The LocalVersion instance to add.
+ """
+ if next((v for v in self.versions if v.tag == version.tag), None) is None:
+ self.versions.append(version)
+ self.write()
+
+ def get_current_version(self, *, tag: str | None = None) -> LocalVersion | None:
+ """Get the current active version or a specific version by tag."""
+ if tag:
+ return next((v for v in self.versions if v.tag == tag), None)
+
+ if current := next((v for v in self.versions if v.current), None):
+ return current
+
+ if latest := self.get_latest_version():
+ self.set_current_version(latest)
+ return latest
+
+ return None
+
+ def get_latest_version(self) -> LocalVersion | None:
+ """Get the latest version based on semantic versioning."""
+ if not self.versions:
+ return None
+ sorted_versions = sorted(
+ self.versions,
+ key=lambda v: tag_to_semver(v.tag),
+ reverse=True,
+ )
+ return sorted_versions[0]
+
+ def get_by_tag(self, tag: str) -> LocalVersion:
+ """
+ Get a specific local platform version by its tag.
+
+ Args:
+ tag: The tag of the version to retrieve.
+
+ Returns:
+ LocalVersion: The version schema matching the provided tag.
+
+ Raises:
+ ValueError: If no version with the specified tag is found.
+ """
+ for version in self.versions:
+ if version.tag == tag:
+ return version
+ raise ValueError(f"No local version found with tag: {tag}")
+
+ def set_current_version(self, version: LocalVersion) -> None:
+ """
+ Mark a specific version as the current active version.
+
+ Updates the versions manifest to mark the specified version as current
+ and all others as not current.
+
+ Args:
+ version: The version to mark as current.
+ """
+ if next((v for v in self.versions if v.tag == version.tag), None) is None:
+ self.versions.append(version)
+
+ for available_version in self.versions:
+ if available_version.tag == version.tag:
+ available_version.current = True
+ else:
+ available_version.current = False
+
+ self.write()
diff --git a/dreadnode/cli/profile/cli.py b/dreadnode/cli/profile/cli.py
index 9ed1ccf8..ce2ac3a2 100644
--- a/dreadnode/cli/profile/cli.py
+++ b/dreadnode/cli/profile/cli.py
@@ -1,12 +1,12 @@
import typing as t
import cyclopts
-import rich
from rich import box
from rich.prompt import Prompt
from rich.table import Table
from dreadnode.cli.api import Token
+from dreadnode.logging_ import console, print_error, print_info, print_success
from dreadnode.user_config import UserConfig
from dreadnode.util import time_to
@@ -19,7 +19,7 @@ def show() -> None:
config = UserConfig.read()
if not config.servers:
- rich.print(":exclamation: No server profiles are configured")
+ print_error("No server profiles are configured")
return
table = Table(box=box.ROUNDED)
@@ -44,7 +44,7 @@ def show() -> None:
style="bold" if active else None,
)
- rich.print(table)
+ console.print(table)
@cli.command()
@@ -55,37 +55,39 @@ def switch(
config = UserConfig.read()
if not config.servers:
- rich.print(":exclamation: No server profiles are configured")
+ print_error("No server profiles are configured")
return
# If no profile provided, prompt user to choose
if profile is None:
profiles = list(config.servers.keys())
- rich.print("\nAvailable profiles:")
+ print_info("Available profiles:")
for i, p in enumerate(profiles, 1):
active_marker = " (current)" if p == config.active else ""
- rich.print(f" {i}. [bold orange_red1]{p}[/]{active_marker}")
+ print_info(f" {i}. [bold orange_red1]{p}[/]{active_marker}")
choice = Prompt.ask(
"\nSelect a profile",
choices=[str(i) for i in range(1, len(profiles) + 1)] + profiles,
show_choices=False,
+ console=console,
)
profile = profiles[int(choice) - 1] if choice.isdigit() else choice
if profile not in config.servers:
- rich.print(f":exclamation: Profile [bold]{profile}[/] does not exist")
+ print_error(f"Profile [bold]{profile}[/] does not exist")
return
config.active = profile
config.write()
- rich.print(f":laptop_computer: Switched to [bold orange_red1]{profile}[/]")
- rich.print(f"|- email: [bold]{config.servers[profile].email}[/]")
- rich.print(f"|- username: {config.servers[profile].username}")
- rich.print(f"|- url: {config.servers[profile].url}")
- rich.print()
+ print_success(
+ f"Switched to [bold orange_red1]{profile}[/]\n"
+ f"|- email: [bold]{config.servers[profile].email}[/]\n"
+ f"|- username: {config.servers[profile].username}\n"
+ f"|- url: {config.servers[profile].url}\n"
+ )
@cli.command()
@@ -95,10 +97,10 @@ def forget(
"""Remove a server profile from the configuration."""
config = UserConfig.read()
if profile not in config.servers:
- rich.print(f":exclamation: Profile [bold]{profile}[/] does not exist")
+ print_error(f"Profile [bold]{profile}[/] does not exist")
return
del config.servers[profile]
config.write()
- rich.print(f":axe: Forgot about [bold]{profile}[/]")
+ print_success(f"Forgot about [bold]{profile}[/]")
diff --git a/dreadnode/cli/shared.py b/dreadnode/cli/shared.py
index 080223b4..2f988698 100644
--- a/dreadnode/cli/shared.py
+++ b/dreadnode/cli/shared.py
@@ -3,7 +3,7 @@
import cyclopts
-from dreadnode.logging_ import LogLevelLiteral, configure_logging
+from dreadnode.logging_ import LogLevel, configure_logging
@cyclopts.Parameter(name="dn", group="Dreadnode")
@@ -19,7 +19,7 @@ class DreadnodeConfig:
"""Profile name"""
console: t.Annotated[bool, cyclopts.Parameter(negative=False)] = False
"""Show spans in the console"""
- log_level: LogLevelLiteral | None = None
+ log_level: LogLevel | None = None
"""Console log level"""
def apply(self) -> None:
diff --git a/dreadnode/logging_.py b/dreadnode/logging_.py
index 09b01a87..b0bb0602 100644
--- a/dreadnode/logging_.py
+++ b/dreadnode/logging_.py
@@ -5,35 +5,32 @@
"""
import pathlib
-import sys
import typing as t
+from textwrap import dedent
from loguru import logger
+from rich.console import Console
+from rich.logging import RichHandler
+from rich.prompt import Confirm
+from rich.theme import Theme
-if t.TYPE_CHECKING:
- from loguru import Record as LogRecord
-
-g_configured: bool = False
-
-LogLevelList = ["trace", "debug", "info", "success", "warning", "error", "critical"]
-LogLevelLiteral = t.Literal["trace", "debug", "info", "success", "warning", "error", "critical"]
+LogLevel = t.Literal["trace", "debug", "info", "success", "warning", "error", "critical"]
"""Valid logging levels."""
-
-def log_formatter(record: "LogRecord") -> str:
- return "".join(
- (
- "{time:HH:mm:ss.SSS} | ",
- "{extra[prefix]} " if record["extra"].get("prefix") else "",
- "{message}\n",
- )
- )
+console = Console(
+ theme=Theme( # rich doesn't include default colors for these
+ {
+ "logging.level.success": "green",
+ "logging.level.trace": "dim blue",
+ }
+ ),
+)
def configure_logging(
- log_level: LogLevelLiteral = "info",
+ log_level: LogLevel = "info",
log_file: pathlib.Path | None = None,
- log_file_level: LogLevelLiteral = "debug",
+ log_file_level: LogLevel = "debug",
) -> None:
"""
Configures common loguru handlers.
@@ -44,26 +41,40 @@ def configure_logging(
will only be done to the console.
log_file_level: The log level for the log file.
"""
- global g_configured # noqa: PLW0603
-
- if g_configured:
- return
-
logger.enable("dreadnode")
- logger.level("TRACE", color="", icon="[T]")
- logger.level("DEBUG", color="", icon="[_]")
- logger.level("INFO", color="", icon="[=]")
- logger.level("SUCCESS", color="", icon="[+]")
- logger.level("WARNING", color="", icon="[-]")
- logger.level("ERROR", color="", icon="[!]")
- logger.level("CRITICAL", color="", icon="[x]")
-
logger.remove()
- logger.add(sys.stderr, format=log_formatter, level=log_level.upper())
+ logger.add(
+ RichHandler(console=console, log_time_format="%X"),
+ format=lambda _: "{message}",
+ level=log_level.upper(),
+ )
if log_file is not None:
- logger.add(log_file, format=log_formatter, level=log_file_level.upper())
+ logger.add(log_file, level=log_file_level.upper())
logger.info(f"Logging to {log_file}")
- g_configured = True
+
+def print_success(message: str) -> None:
+ console.print(f"[bold green]:heavy_check_mark:[/] {dedent(message)}")
+
+
+def print_error(message: str) -> None:
+ console.print(f"[bold red]:heavy_multiplication_x:[/] {dedent(message)}")
+
+
+def print_warning(message: str) -> None:
+ console.print(f"[bold yellow]:warning:[/] {dedent(message)}")
+
+
+def print_info(message: str) -> None:
+ console.print(dedent(message))
+
+
+def confirm(action: str) -> bool:
+ return Confirm.ask(
+ f"[bold magenta]:left-right_arrow:[/] {action}",
+ default=False,
+ case_sensitive=False,
+ console=console,
+ )
diff --git a/dreadnode/main.py b/dreadnode/main.py
index 54da44de..5293654e 100644
--- a/dreadnode/main.py
+++ b/dreadnode/main.py
@@ -10,7 +10,6 @@
import coolname # type: ignore [import-untyped]
import logfire
-import rich
from fsspec.implementations.local import ( # type: ignore [import-untyped]
LocalFileSystem,
)
@@ -41,6 +40,7 @@
ENV_SERVER_URL,
)
from dreadnode.error import AssertionFailedError
+from dreadnode.logging_ import console as logging_console
from dreadnode.metric import (
Metric,
MetricAggMode,
@@ -277,11 +277,13 @@ def configure(
# Log config information for clarity
if self.server or self.token or self.local_dir:
destination = self.server or DEFAULT_SERVER_URL or "local storage"
- rich.print(f"Dreadnode logging to [orange_red1]{destination}[/] ({config_source})")
+ logging_console.print(
+ f"Dreadnode logging to [orange_red1]{destination}[/] ({config_source})"
+ )
# Warn the user if the profile didn't resolve
elif active_profile and not (self.server or self.token):
- rich.print(
+ logging_console.print(
f":exclamation: Dreadnode profile [orange_red1]{active_profile}[/] appears invalid."
)
diff --git a/dreadnode/user_config.py b/dreadnode/user_config.py
index f1daa806..2131910f 100644
--- a/dreadnode/user_config.py
+++ b/dreadnode/user_config.py
@@ -1,8 +1,8 @@
-import rich
from pydantic import BaseModel
from ruamel.yaml import YAML
from dreadnode.constants import DEFAULT_PROFILE_NAME, USER_CONFIG_PATH
+from dreadnode.logging_ import print_info
class ServerConfig(BaseModel):
@@ -62,7 +62,7 @@ def write(self) -> None:
self._update_active()
if not USER_CONFIG_PATH.parent.exists():
- rich.print(f":rocket: Creating config at {USER_CONFIG_PATH.parent}")
+ print_info(f"Creating config at {USER_CONFIG_PATH.parent}")
USER_CONFIG_PATH.parent.mkdir(parents=True)
with USER_CONFIG_PATH.open("w") as f:
diff --git a/pyproject.toml b/pyproject.toml
index 3701e5c2..717db1c2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -182,7 +182,7 @@ skip-magic-trailing-comma = false
]
"tests/**/*.py" = [
"INP001", # namespace not required for pytest
- "S101", # asserts allowed in tests...
"SLF001", # allow access to private members
"PLR2004", # magic values
+ "S1", # security issues in tests are not relevant
]
diff --git a/tests/cli/test_config.py b/tests/cli/test_config.py
new file mode 100644
index 00000000..a82a9457
--- /dev/null
+++ b/tests/cli/test_config.py
@@ -0,0 +1,79 @@
+from pathlib import Path
+
+import pydantic
+import pytest
+
+from dreadnode.user_config import ServerConfig, UserConfig
+
+
+def test_server_config() -> None:
+ # Test valid server config
+ config = ServerConfig(
+ url="https://platform.dreadnode.io",
+ email="test@example.com",
+ username="test",
+ api_key="test123", # pragma: allowlist secret
+ access_token="token123",
+ refresh_token="refresh123",
+ )
+ assert config.url == "https://platform.dreadnode.io"
+ assert config.email == "test@example.com"
+ assert config.username == "test"
+ assert config.api_key == "test123" # pragma: allowlist secret
+ assert config.access_token == "token123"
+ assert config.refresh_token == "refresh123"
+
+ # Test invalid server config model
+ with pytest.raises(pydantic.ValidationError):
+ ServerConfig.model_validate({"invalid": "data"})
+
+
+def test_user_config(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
+ # Mock config path to use temporary directory
+ mock_config_path = tmp_path / "config.yaml"
+ monkeypatch.setattr("dreadnode.user_config.USER_CONFIG_PATH", mock_config_path)
+
+ # Test empty config
+ config = UserConfig()
+ assert config.active is None
+ assert config.servers == {}
+
+ # Test adding server config
+ server_config = ServerConfig(
+ url="https://platform.dreadnode.io",
+ email="test@example.com",
+ username="test",
+ api_key="test123", # pragma: allowlist secret
+ access_token="token123",
+ refresh_token="refresh123",
+ )
+
+ config.set_server_config(server_config, "default")
+ assert "default" in config.servers
+ assert config.get_server_config("default") == server_config
+
+ # Test active profile
+ config.active = "default"
+ assert config.get_server_config() == server_config
+
+ # Test writing and reading config
+ config.write()
+ assert mock_config_path.exists()
+
+ loaded_config = UserConfig.read()
+ assert loaded_config.active == "default"
+ assert loaded_config.servers["default"] == server_config
+
+ # Test invalid profile access
+ with pytest.raises(RuntimeError):
+ config.get_server_config("nonexistent")
+
+ # Test auto-setting active profile
+ config.active = None
+ config._update_active()
+ assert config.active == "default"
+
+ # Test empty config edge case
+ empty_config = UserConfig()
+ empty_config._update_active()
+ assert empty_config.active is None
diff --git a/tests/cli/test_docker.py b/tests/cli/test_docker.py
new file mode 100644
index 00000000..eea20641
--- /dev/null
+++ b/tests/cli/test_docker.py
@@ -0,0 +1,161 @@
+import pytest
+
+from dreadnode.cli.docker import DockerImage
+
+
+@pytest.mark.parametrize(
+ ("input_str", "expected_repo", "expected_tag"),
+ [
+ ("ubuntu", "library/ubuntu", "latest"),
+ ("postgres", "library/postgres", "latest"),
+ ("alpine:3.18", "library/alpine", "3.18"),
+ ("redis:7", "library/redis", "7"),
+ ],
+)
+def test_docker_image_official_images(
+ input_str: str, expected_repo: str, expected_tag: str
+) -> None:
+ """Tests parsing of official Docker Hub images (e.g., 'ubuntu')."""
+ img = DockerImage(input_str)
+ assert img.registry is None # Should be None if not provided
+ assert img.repository == expected_repo
+ assert img.tag == expected_tag
+ assert img.digest is None
+
+
+@pytest.mark.parametrize(
+ ("input_str", "expected_repo", "expected_tag"),
+ [
+ ("dreadnode/image", "dreadnode/image", "latest"),
+ ("bitnami/postgresql", "bitnami/postgresql", "latest"),
+ ("minio/minio:RELEASE.2023-03-20T20-16-18Z", "minio/minio", "RELEASE.2023-03-20T20-16-18Z"),
+ ],
+)
+def test_docker_image_namespaced_images(
+ input_str: str, expected_repo: str, expected_tag: str
+) -> None:
+ """Tests parsing of namespaced Docker Hub images (e.g., 'dreadnode/image')."""
+ img = DockerImage(input_str)
+ assert img.registry is None # Should be None if not provided
+ assert img.repository == expected_repo
+ assert img.tag == expected_tag
+ assert img.digest is None
+
+
+@pytest.mark.parametrize(
+ ("input_str", "expected_registry", "expected_repo", "expected_tag"),
+ [
+ ("gcr.io/google-samples/hello-app:1.0", "gcr.io", "google-samples/hello-app", "1.0"),
+ ("ghcr.io/owner/image:tag", "ghcr.io", "owner/image", "tag"),
+ ("localhost:5000/my-app", "localhost:5000", "my-app", "latest"),
+ ("my.registry:1234/a/b/c:v1", "my.registry:1234", "a/b/c", "v1"),
+ ],
+)
+def test_docker_image_with_custom_registry(
+ input_str: str, expected_registry: str, expected_repo: str, expected_tag: str
+) -> None:
+ """Tests parsing of images with a full registry hostname."""
+ img = DockerImage(input_str)
+ assert img.registry == expected_registry
+ assert img.repository == expected_repo
+ assert img.tag == expected_tag
+ assert img.digest is None
+
+
+@pytest.mark.parametrize(
+ ("input_str", "expected_repo", "expected_tag", "expected_digest"),
+ [
+ ("ubuntu@sha256:abc", "library/ubuntu", None, "sha256:abc"),
+ ("dreadnode/image@sha256:123", "dreadnode/image", None, "sha256:123"),
+ ("gcr.io/app/image@sha256:xyz", "app/image", None, "sha256:xyz"),
+ ("ubuntu:22.04@sha256:456", "library/ubuntu", "22.04", "sha256:456"),
+ ],
+)
+def test_docker_image_with_digest(
+ input_str: str, expected_repo: str, expected_tag: str | None, expected_digest: str
+) -> None:
+ """Tests parsing of images with a digest."""
+ img = DockerImage(input_str)
+ assert img.repository == expected_repo
+ assert img.tag == expected_tag
+ assert img.digest == expected_digest
+
+
+def test_docker_image_self_format_and_normalization() -> None:
+ """Test that DockerImage can handle its own normalized string outputs."""
+ img1 = DockerImage("ubuntu")
+ assert str(img1) == "library/ubuntu:latest"
+
+ img2 = DockerImage(str(img1))
+ assert img2.registry is None
+ assert img2.repository == "library/ubuntu"
+ assert img2.tag == "latest"
+ assert img1 == img2
+
+
+def test_docker_image_whitespace_handling() -> None:
+ """Test that leading/trailing whitespace is properly stripped."""
+ img = DockerImage(" ubuntu:22.04 \n")
+ assert img.repository == "library/ubuntu"
+ assert img.tag == "22.04"
+ assert str(img) == "library/ubuntu:22.04"
+
+
+@pytest.mark.parametrize(
+ "case",
+ [
+ "", # Empty string
+ " ", # Just whitespace
+ "@sha256:123", # Just a digest
+ ],
+)
+def test_docker_image_invalid_formats(case: str) -> None:
+ """Test that invalid formats raise ValueError."""
+ with pytest.raises(ValueError, match="Invalid Docker image format"):
+ DockerImage(case)
+
+
+def test_docker_image_string_methods_inheritance() -> None:
+ """Test that string methods from the str parent class work as expected."""
+ img = DockerImage("ubuntu:22.04")
+ assert str(img) == "library/ubuntu:22.04"
+ assert img.upper() == "LIBRARY/UBUNTU:22.04"
+ assert img.startswith("library")
+
+
+def test_docker_image_comparisons() -> None:
+ """Test comparison operations."""
+ img1 = DockerImage("ubuntu")
+ img2 = DockerImage("library/ubuntu:latest")
+ img3 = DockerImage("postgres")
+ img4 = DockerImage("docker.io/library/ubuntu:latest")
+
+ assert img1 == img2
+ assert img1 != img3
+ assert img1 != img4 # These are now different, as one specifies a registry
+ assert img1 == "library/ubuntu:latest"
+
+
+def test_docker_image_with_method() -> None:
+ """Tests the with_() method for creating modified copies."""
+ original = DockerImage("gcr.io/project/image:1.0")
+
+ # Change registry
+ with_new_registry = original.with_(registry="ghcr.io")
+ assert str(with_new_registry) == "ghcr.io/project/image:1.0"
+ assert with_new_registry.registry == "ghcr.io"
+
+ # Remove registry by setting to None
+ with_no_registry = original.with_(registry=None)
+ assert str(with_no_registry) == "project/image:1.0"
+ assert with_no_registry.registry is None
+
+ # Change tag and remove digest
+ with_digest = DockerImage("gcr.io/project/image:1.0@sha256:abc")
+ with_new_tag = with_digest.with_(tag="2.0-beta", digest=None)
+ assert str(with_new_tag) == "gcr.io/project/image:2.0-beta"
+ assert with_new_tag.tag == "2.0-beta"
+ assert with_new_tag.digest is None
+
+ # Ensure original object is not mutated
+ assert str(original) == "gcr.io/project/image:1.0"
diff --git a/tests/cli/test_github.py b/tests/cli/test_github.py
new file mode 100644
index 00000000..6fc680ba
--- /dev/null
+++ b/tests/cli/test_github.py
@@ -0,0 +1,201 @@
+import pytest
+
+from dreadnode.cli.github import GithubRepo
+
+
+def test_github_repo_simple_format() -> None:
+ repo = GithubRepo("owner/repo")
+ assert repo.namespace == "owner"
+ assert repo.repo == "repo"
+ assert repo.ref == "main"
+ assert str(repo) == "owner/repo@main"
+
+
+def test_github_repo_simple_format_with_ref() -> None:
+ repo = GithubRepo("owner/repo/tree/develop")
+ assert repo.namespace == "owner"
+ assert repo.repo == "repo"
+ assert repo.ref == "develop"
+ assert str(repo) == "owner/repo@develop"
+
+
+@pytest.mark.parametrize(
+ "case",
+ [
+ "https://github.com/owner/repo",
+ "http://github.com/owner/repo",
+ "https://github.com/owner/repo.git",
+ ],
+)
+def test_github_repo_https_url(case: str) -> None:
+ repo = GithubRepo(case)
+ assert repo.namespace == "owner"
+ assert repo.repo == "repo"
+ assert repo.ref == "main"
+ assert str(repo) == "owner/repo@main"
+
+
+@pytest.mark.parametrize(
+ ("case", "expected_ref"),
+ [
+ ("https://github.com/owner/repo/tree/feature/custom-branch", "feature/custom-branch"),
+ ("https://github.com/owner/repo/blob/feature/custom-branch", "feature/custom-branch"),
+ ],
+)
+def test_github_repo_https_url_with_ref(case: str, expected_ref: str) -> None:
+ repo = GithubRepo(case)
+ assert repo.namespace == "owner"
+ assert repo.repo == "repo"
+ assert repo.ref == expected_ref
+ assert str(repo) == f"owner/repo@{expected_ref}"
+
+
+@pytest.mark.parametrize(
+ "case",
+ [
+ "git@github.com:owner/repo",
+ "git@github.com:owner/repo.git",
+ ],
+)
+def test_github_repo_ssh_url(case: str) -> None:
+ repo = GithubRepo(case)
+ assert repo.namespace == "owner"
+ assert repo.repo == "repo"
+ assert repo.ref == "main"
+ assert str(repo) == "owner/repo@main"
+
+
+@pytest.mark.parametrize(
+ ("case", "expected_ref"),
+ [
+ ("https://raw.githubusercontent.com/owner/repo/main", "main"),
+ ("https://raw.githubusercontent.com/owner/repo/feature-branch", "feature-branch"),
+ ("https://raw.githubusercontent.com/owner/repo/feature/branch", "feature/branch"),
+ ],
+)
+def test_github_repo_raw_githubusercontent(case: str, expected_ref: str) -> None:
+ repo = GithubRepo(case)
+ assert repo.namespace == "owner"
+ assert repo.repo == "repo"
+ assert repo.ref == expected_ref
+ assert str(repo) == f"owner/repo@{expected_ref}"
+
+
+@pytest.mark.parametrize(
+ ("input_str", "expected_ref"),
+ [
+ ("owner/repo/tree/feature/custom", "feature/custom"),
+ ("owner/repo/releases/tag/v1.0.0", "v1.0.0"),
+ ],
+)
+def test_github_repo_ref_handling(input_str: str, expected_ref: str) -> None:
+ """Test handling of different reference formats"""
+ repo = GithubRepo(input_str)
+ assert repo.namespace == "owner"
+ assert repo.repo == "repo"
+ assert repo.ref == expected_ref
+ assert repo.zip_url == f"https://github.com/owner/repo/zipball/{expected_ref}"
+
+
+@pytest.mark.parametrize(
+ "case",
+ [
+ "owner/repo.js",
+ "https://github.com/owner/repo.js",
+ "git@github.com:owner/repo.js.git",
+ ],
+)
+def test_github_repo_with_dots(case: str) -> None:
+ """Test repositories with dots in names"""
+ repo = GithubRepo(case)
+ assert repo.namespace == "owner"
+ assert repo.repo == "repo.js"
+ assert str(repo) == "owner/repo.js@main"
+
+
+@pytest.mark.parametrize(
+ "case",
+ [
+ "owner-name/repo-name",
+ "https://github.com/owner-name/repo-name",
+ "git@github.com:owner-name/repo-name.git",
+ ],
+)
+def test_github_repo_with_dashes(case: str) -> None:
+ """Test repositories with dashes in names"""
+ repo = GithubRepo(case)
+ assert repo.namespace == "owner-name"
+ assert repo.repo == "repo-name"
+ assert str(repo) == "owner-name/repo-name@main"
+
+
+@pytest.mark.parametrize(
+ "case",
+ [
+ " owner/repo ",
+ "\nowner/repo\n",
+ "\towner/repo\t",
+ ],
+)
+def test_github_repo_whitespace_handling(case: str) -> None:
+ """Test that whitespace is properly stripped"""
+ repo = GithubRepo(case)
+ assert repo.namespace == "owner"
+ assert repo.repo == "repo"
+ assert str(repo) == "owner/repo@main"
+
+
+@pytest.mark.parametrize(
+ "case",
+ [
+ "", # Empty string
+ "owner", # Missing repo
+ "owner/", # Missing repo
+ "/repo", # Missing owner
+ "owner/repo/extra", # Too many parts
+ "http://gitlab.com/owner/repo", # Wrong domain
+ "git@gitlab.com:owner/repo.git", # Wrong domain
+ ],
+)
+def test_github_repo_invalid_formats(case: str) -> None:
+ """Test that invalid formats raise ValueError"""
+ with pytest.raises(ValueError, match="Invalid GitHub repository format"):
+ GithubRepo(case)
+
+
+def test_github_repo_string_methods_inheritance() -> None:
+ """Test that string methods work as expected"""
+ repo = GithubRepo("owner/repo")
+ assert repo.upper() == "OWNER/REPO@MAIN"
+ assert repo.split("/") == ["owner", "repo@main"]
+ assert repo.split("@") == ["owner/repo", "main"]
+ assert repo.replace("owner", "newowner") == "newowner/repo@main"
+ assert len(repo) == len("owner/repo@main")
+
+
+def test_github_repo_comparisons() -> None:
+ """Test comparison operations"""
+ repo1 = GithubRepo("owner/repo")
+ repo2 = GithubRepo("owner/repo")
+ repo3 = GithubRepo("different/repo")
+
+ assert repo1 == repo2
+ assert repo1 != repo3
+ assert repo1 == "owner/repo@main"
+
+
+def test_github_repo_self_format() -> None:
+ """Test that GithubRepo can handle its own string representations"""
+ # Test basic format
+ repo1 = GithubRepo("owner/repo@main")
+ assert repo1.namespace == "owner"
+ assert repo1.repo == "repo"
+ assert repo1.ref == "main"
+ assert str(repo1) == "owner/repo@main"
+
+ # Test creating from existing repo string
+ repo2 = GithubRepo(str(repo1))
+ assert repo2.namespace == "owner"
+ assert repo2.repo == "repo"
+ assert repo2.ref == "main"
+ assert str(repo2) == str(repo1)
diff --git a/uv.lock b/uv.lock
index 94a607e9..ef020a36 100644
--- a/uv.lock
+++ b/uv.lock
@@ -681,7 +681,7 @@ wheels = [
[[package]]
name = "dreadnode"
-version = "1.14.0"
+version = "1.14.1"
source = { editable = "." }
dependencies = [
{ name = "coolname" },