diff --git a/api/src/app/core/cdktf/ranges/base_range.py b/api/src/app/core/cdktf/ranges/base_range.py index c04088aa..4056ce7c 100644 --- a/api/src/app/core/cdktf/ranges/base_range.py +++ b/api/src/app/core/cdktf/ranges/base_range.py @@ -21,6 +21,7 @@ DeployedRangeSchema, ) from ....schemas.secret_schema import SecretSchema +from ....utils.cdktf_utils import gen_resource_logical_ids from ....utils.name_utils import normalize_name from ...config import settings from ..stacks.base_stack import AbstractBaseStack @@ -32,15 +33,6 @@ class AbstractBaseRange(ABC): """Abstract class to enforce common functionality across range cloud providers.""" - name: str - range_obj: BlueprintRangeSchema | DeployedRangeSchema - state_file: dict[str, Any] | None # Terraform state - region: OpenLabsRegion - stack_name: str - secrets: SecretSchema - deployed_range_name: str - description: str - # Mutex for terraform init calls _init_lock = asyncio.Lock() @@ -422,38 +414,42 @@ async def _parse_terraform_outputs( # noqa: PLR0911 dumped_schema["jumpbox_public_ip"] = raw_outputs[jumpbox_ip_key]["value"] dumped_schema["range_private_key"] = raw_outputs[private_key]["value"] + vpc_logical_ids = gen_resource_logical_ids( + [vpc.name for vpc in self.range_obj.vpcs] + ) for x, vpc in enumerate(self.range_obj.vpcs): - - normalized_vpc_name = normalize_name(vpc.name) - + vpc_logical_id = vpc_logical_ids[vpc.name] current_vpc = dumped_schema["vpcs"][x] + vpc_key = next( ( key for key in raw_outputs - if key.endswith(f"-{normalized_vpc_name}-resource-id") + if key.endswith(f"-{vpc_logical_id}-resource-id") ), None, ) if not vpc_key: logger.error( "Could not find VPC resource ID key for %s in Terraform output", - normalized_vpc_name, + vpc_logical_id, ) return None current_vpc["resource_id"] = raw_outputs[vpc_key]["value"] + subnet_logical_ids = gen_resource_logical_ids( + [subnet.name for subnet in vpc.subnets] # type: ignore + ) for y, subnet in enumerate(vpc.subnets): # type: ignore - - normalized_subnet_name = normalize_name(subnet.name) - + subnet_logical_id = subnet_logical_ids[subnet.name] current_subnet = current_vpc["subnets"][y] + subnet_key = next( ( key for key in raw_outputs if key.endswith( - f"-{normalized_vpc_name}-{normalized_subnet_name}-resource-id" + f"-{vpc_logical_id}-{subnet_logical_id}-resource-id" ) ), None, @@ -461,8 +457,8 @@ async def _parse_terraform_outputs( # noqa: PLR0911 if not subnet_key: logger.error( "Could not find subnet resource ID key for %s in %s in Terraform output", - normalized_subnet_name, - normalized_vpc_name, + subnet_logical_id, + vpc_logical_id, ) return None current_subnet["resource_id"] = raw_outputs[subnet_key]["value"] @@ -474,7 +470,7 @@ async def _parse_terraform_outputs( # noqa: PLR0911 key for key in raw_outputs if key.endswith( - f"-{normalized_vpc_name}-{normalized_subnet_name}-{host.hostname}-resource-id" + f"-{vpc_logical_id}-{subnet_logical_id}-{host.hostname}-resource-id" ) ), None, @@ -484,7 +480,7 @@ async def _parse_terraform_outputs( # noqa: PLR0911 key for key in raw_outputs if key.endswith( - f"-{normalized_vpc_name}-{normalized_subnet_name}-{host.hostname}-private-ip" + f"-{vpc_logical_id}-{subnet_logical_id}-{host.hostname}-private-ip" ) ), None, @@ -494,8 +490,8 @@ async def _parse_terraform_outputs( # noqa: PLR0911 logger.error( "Could not find host keys for %s in %s/%s in Terraform output", host.hostname, - normalized_vpc_name, - normalized_subnet_name, + vpc_logical_id, + subnet_logical_id, ) return None diff --git a/api/src/app/core/cdktf/stacks/aws_stack.py b/api/src/app/core/cdktf/stacks/aws_stack.py index b88ce0d9..98b6a8ad 100644 --- a/api/src/app/core/cdktf/stacks/aws_stack.py +++ b/api/src/app/core/cdktf/stacks/aws_stack.py @@ -22,8 +22,8 @@ from ....enums.regions import AWS_REGION_MAP, OpenLabsRegion from ....enums.specs import AWS_SPEC_MAP from ....schemas.range_schemas import BlueprintRangeSchema, DeployedRangeSchema +from ....utils.cdktf_utils import gen_resource_logical_ids from ....utils.crypto import generate_range_rsa_key_pair -from ....utils.name_utils import normalize_name from .base_stack import AbstractBaseStack @@ -267,23 +267,23 @@ def build_resources( ) # Create Range vpcs, subnets, hosts + vpc_logical_ids = gen_resource_logical_ids([vpc.name for vpc in range_obj.vpcs]) for vpc in range_obj.vpcs: - - normalized_vpc_name = normalize_name(vpc.name) + vpc_logical_id = vpc_logical_ids[vpc.name] # Step 14: Create a VPC new_vpc = Vpc( self, - f"{range_name}-{normalized_vpc_name}", + f"{range_name}-{vpc_logical_id}", cidr_block=str(vpc.cidr), enable_dns_support=True, enable_dns_hostnames=True, - tags={"Name": normalized_vpc_name}, + tags={"Name": vpc_logical_id}, ) TerraformOutput( self, - f"{range_name}-{normalized_vpc_name}-resource-id", + f"{range_name}-{vpc_logical_id}-resource-id", value=new_vpc.id, description="Cloud resource id of the vpc created", sensitive=True, @@ -294,13 +294,13 @@ def build_resources( # Every VPC will use the same secrutiy group but security groups are scoped to a single VPC, so they have to be added to each one private_vpc_sg = SecurityGroup( self, - f"{range_name}-{normalized_vpc_name}-SharedPrivateSG", + f"{range_name}-{vpc_logical_id}-SharedPrivateSG", vpc_id=new_vpc.id, tags={"Name": "RangePrivateInternalSecurityGroup"}, ) SecurityGroupRule( # Allow access from the Jumpbox - possibly not needed based on next rule self, - f"{range_name}-{normalized_vpc_name}-RangeAllowAllTrafficFromJumpBox-Rule", + f"{range_name}-{vpc_logical_id}-RangeAllowAllTrafficFromJumpBox-Rule", type="ingress", from_port=0, to_port=0, @@ -310,7 +310,7 @@ def build_resources( ) SecurityGroupRule( self, - f"{range_name}-{normalized_vpc_name}-RangeAllowInternalTraffic-Rule", # Allow all internal subnets to communicate with each other + f"{range_name}-{vpc_logical_id}-RangeAllowInternalTraffic-Rule", # Allow all internal subnets to communicate with each other type="ingress", from_port=0, to_port=0, @@ -320,7 +320,7 @@ def build_resources( ) SecurityGroupRule( self, - f"{range_name}-{normalized_vpc_name}-RangeAllowPrivateOutbound-Rule", + f"{range_name}-{vpc_logical_id}-RangeAllowPrivateOutbound-Rule", type="egress", from_port=0, to_port=0, @@ -330,23 +330,25 @@ def build_resources( ) current_vpc_subnets: list[Subnet] = [] + subnet_logical_ids = gen_resource_logical_ids( + [subnet.name for subnet in vpc.subnets] + ) # Step 16: Create private subnets with their respecitve EC2 instances for subnet in vpc.subnets: - - normalized_subnet_name = normalize_name(subnet.name) + subnet_logical_id = subnet_logical_ids[subnet.name] new_subnet = Subnet( self, - f"{range_name}-{normalized_vpc_name}-{normalized_subnet_name}", + f"{range_name}-{vpc_logical_id}-{subnet_logical_id}", vpc_id=new_vpc.id, cidr_block=str(subnet.cidr), availability_zone="us-east-1a", - tags={"Name": normalized_subnet_name}, + tags={"Name": subnet_logical_id}, ) TerraformOutput( self, - f"{range_name}-{normalized_vpc_name}-{normalized_subnet_name}-resource-id", + f"{range_name}-{vpc_logical_id}-{subnet_logical_id}-resource-id", value=new_subnet.id, description="Cloud resource id of the subnet created", sensitive=True, @@ -358,7 +360,7 @@ def build_resources( for host in subnet.hosts: ec2_instance = Instance( self, - f"{range_name}-{normalized_vpc_name}-{normalized_subnet_name}-{host.hostname}", + f"{range_name}-{vpc_logical_id}-{subnet_logical_id}-{host.hostname}", ami=AWS_OS_MAP[host.os], instance_type=AWS_SPEC_MAP[host.spec], subnet_id=new_subnet.id, @@ -369,14 +371,14 @@ def build_resources( TerraformOutput( self, - f"{range_name}-{normalized_vpc_name}-{normalized_subnet_name}-{host.hostname}-resource-id", + f"{range_name}-{vpc_logical_id}-{subnet_logical_id}-{host.hostname}-resource-id", value=ec2_instance.id, description="Cloud resource id of the ec2 instance created", sensitive=True, ) TerraformOutput( self, - f"{range_name}-{normalized_vpc_name}-{normalized_subnet_name}-{host.hostname}-private-ip", + f"{range_name}-{vpc_logical_id}-{subnet_logical_id}-{host.hostname}-private-ip", value=ec2_instance.private_ip, description="Cloud private IP address of the ec2 instance created", sensitive=True, @@ -385,7 +387,7 @@ def build_resources( # Step 17: Attach VPC to Transit Gateway private_vpc_tgw_attachment = Ec2TransitGatewayVpcAttachment( # noqa: F841 self, - f"{range_name}-{normalized_vpc_name}-PrivateVpcTgwAttachment", + f"{range_name}-{vpc_logical_id}-PrivateVpcTgwAttachment", subnet_ids=[ current_vpc_subnets[0].id ], # Attach TGW ENIs to all private subnets @@ -393,20 +395,20 @@ def build_resources( vpc_id=new_vpc.id, transit_gateway_default_route_table_association=True, transit_gateway_default_route_table_propagation=True, - tags={"Name": f"{normalized_vpc_name}-private-vpc-tgw-attachment"}, + tags={"Name": f"{vpc_logical_id}-private-vpc-tgw-attachment"}, ) # Step 18: Create Routing in range VPC (Routes to TGW to access other range VPCs or the internet via the NAT gateway) new_vpc_private_route_table = RouteTable( self, - f"{range_name}-{normalized_vpc_name}-PrivateRouteTable", + f"{range_name}-{vpc_logical_id}-PrivateRouteTable", vpc_id=new_vpc.id, - tags={"Name": f"{normalized_vpc_name}-private-route-table"}, + tags={"Name": f"{vpc_logical_id}-private-route-table"}, ) # Default route for range VPC to Transit Gateway tgw_route = Route( # noqa: F841 self, - f"{range_name}-{normalized_vpc_name}-PrivateTgwRoute", + f"{range_name}-{vpc_logical_id}-PrivateTgwRoute", route_table_id=new_vpc_private_route_table.id, destination_cidr_block="0.0.0.0/0", # All traffic goes to TGW transit_gateway_id=tgw.id, @@ -415,7 +417,7 @@ def build_resources( for i, created_subnet in enumerate(current_vpc_subnets): RouteTableAssociation( self, - f"{range_name}-{normalized_vpc_name}-PrivateSubnetRouteTableAssociation_{i+1}", + f"{range_name}-{vpc_logical_id}-PrivateSubnetRouteTableAssociation_{i+1}", subnet_id=str(created_subnet.id), route_table_id=new_vpc_private_route_table.id, ) @@ -425,7 +427,7 @@ def build_resources( # Add route to the Jumpbox VPC's Public route table (for Jumpbox access & NAT Return Traffic) Route( self, - f"{range_name}-{normalized_vpc_name}-PublicRtbToPrivateVpcRoute", + f"{range_name}-{vpc_logical_id}-PublicRtbToPrivateVpcRoute", route_table_id=jumpbox_route_table.id, # Route in the public subnet's RT destination_cidr_block=new_vpc.cidr_block, # Traffic destined to the range VPCs will go through the transit gateway transit_gateway_id=tgw.id, @@ -434,7 +436,7 @@ def build_resources( # This ensures traffic arriving *from* the TGW destined for another private VPC goes back *to* the TGW Route( self, - f"{range_name}-{normalized_vpc_name}-PublicVpcTgwSubnetRtbToPrivateVpcRoute", + f"{range_name}-{vpc_logical_id}-PublicVpcTgwSubnetRtbToPrivateVpcRoute", route_table_id=nat_route_table.id, # Route in the TGW attachment subnet's Route Table (jumpbox private subnet) destination_cidr_block=new_vpc.cidr_block, # Traffic destined to the range VPCs will go through the transit gateway transit_gateway_id=tgw.id, diff --git a/api/src/app/utils/cdktf_utils.py b/api/src/app/utils/cdktf_utils.py index 36924f72..163c0ceb 100644 --- a/api/src/app/utils/cdktf_utils.py +++ b/api/src/app/utils/cdktf_utils.py @@ -1,7 +1,63 @@ import tempfile +from collections import Counter, defaultdict + +from .name_utils import normalize_name def create_cdktf_dir() -> str: """Create temp dir for CDKTF.""" # /tmp/.openlabs-cdktf-XXXX return tempfile.mkdtemp(prefix=".openlabs-cdktf-") + + +def gen_resource_logical_ids(resource_names: list[str]) -> dict[str, str]: + """Generate deterministic, normalized, and unique logical IDs from a list of resource names. + + This function handles collisions that occur after normalization by appending + a numeric suffix. + + Args: + resource_names: A list of user-supplied resource names. + + Returns: + A dictionary mapping each original resource name to its unique logical ID. + + Example: + >>> names = ["Web Server", "Database", "web-server", "Auth Service"] + >>> gen_resource_logical_ids(names) + { + 'Auth Service': 'auth-service', + 'Database': 'database', + 'Web Server': 'web-server', + 'web-server': 'web-server-1' + } + + """ + if len(set(resource_names)) != len(resource_names): + counts = Counter(resource_names) + duplicates = [name for name, count in counts.items() if count > 1] + msg = f"Input list contains exact duplicate names: {', '.join(duplicates)}" + raise ValueError(msg) + + logical_ids: dict[str, str] = {} + seen_counts: defaultdict[str, int] = defaultdict(int) + + # Sorted to ensure deterministic ID generation + sorted_names = sorted(resource_names) + + for name in sorted_names: + base_id = normalize_name(name) + + # The first time we see a base_id, its ID is just itself + # each time we see it again we add a suffix + if seen_counts[base_id] > 0: + logical_id = f"{base_id}-{seen_counts[base_id]}" + else: + logical_id = base_id + + logical_ids[name] = logical_id + + # Increment the count for the next potential collision + seen_counts[base_id] += 1 + + return logical_ids diff --git a/api/src/app/utils/name_utils.py b/api/src/app/utils/name_utils.py index 3cd5f604..d48fc5e1 100644 --- a/api/src/app/utils/name_utils.py +++ b/api/src/app/utils/name_utils.py @@ -1,3 +1,30 @@ +import re + + def normalize_name(name: str) -> str: - """Remove problematic characters from user-supplied names.""" - return name.strip().replace(" ", "") + """Remove problematic characters from user-supplied names for cloud deployments while maintaining readability. + + Args: + name: Name to normalize. + + Returns: + Name normalized to a safe kebab case version. + + """ + normalized_name = name.lower() + + # Strip out disallowed characters + normalized_name = re.sub(r"[^a-z0-9\-]", "-", normalized_name) + + # Remove extra hyphens + normalized_name = re.sub(r"-+", "-", normalized_name) + normalized_name = normalized_name.strip("-") + + if not normalized_name: + msg = ( + f"Name is empty after normalization. Original name: '{name}'. " + "A valid name must contain at least one alphanumeric character after normalization." + ) + raise ValueError(msg) + + return normalized_name diff --git a/api/tests/common/api/v1/config.py b/api/tests/common/api/v1/config.py index e6e9821f..5ad1bb16 100644 --- a/api/tests/common/api/v1/config.py +++ b/api/tests/common/api/v1/config.py @@ -130,6 +130,21 @@ } ], }, + # This is here to test that we handle + # normalized name collisions + { + "cidr": "10.0.3.0/24", + "name": "dev_subnet_app", + "hosts": [ + { + "hostname": "dev-app-01", + "os": "debian_11", + "spec": "small", + "size": 30, + "tags": ["app", "linux"], + } + ], + }, ], }, { diff --git a/api/tests/deploy_test_utils.py b/api/tests/deploy_test_utils.py index 451da493..089d6522 100644 --- a/api/tests/deploy_test_utils.py +++ b/api/tests/deploy_test_utils.py @@ -1,11 +1,20 @@ +import asyncio import logging import os from contextlib import asynccontextmanager from enum import Enum +from socket import error as SocketError # noqa: N812 from typing import AsyncGenerator +import paramiko from httpx import AsyncClient +from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed +from src.app.enums.operating_systems import ( + AWS_SSH_USERNAME_MAP, + AZURE_SSH_USERNAME_MAP, + OpenLabsOS, +) from src.app.enums.providers import OpenLabsProvider logger = logging.getLogger(__name__) @@ -105,3 +114,75 @@ def range_test_id(range_type: RangeType) -> str | None: type(range_type), ) return None + + +RETRYABLE_EXCEPTIONS = ( + paramiko.SSHException, + TimeoutError, + SocketError, +) + + +@retry( + stop=stop_after_attempt(3), + wait=wait_fixed(5), # Wait 5 seconds between retries + retry=retry_if_exception_type(RETRYABLE_EXCEPTIONS), + reraise=True, # Reraise the last exception if all retries fail +) +async def ssh_connect_to_host( + hostname: str, + username: str, + private_key: paramiko.PKey, + jumpbox_transport: paramiko.Transport | None = None, + jumpbox_public_ip: str | None = None, +) -> paramiko.SSHClient: + """Establish a SSH connection to a host with retry logic. + + This function can connect directly to a host or tunnel through a jumpbox + by providing an active `paramiko.Transport`. + + Args: + hostname: The IP address or hostname of the target machine. + username: The SSH username for the target machine. + private_key: The paramiko private key object for authentication. + jumpbox_transport: Optional transport from an existing jumpbox client for tunneling. + jumpbox_public_ip: The jumpbox's public IP, required for tunneling. + + Returns: + A connected `paramiko.SSHClient` instance. + + """ + target_client = paramiko.SSHClient() + target_client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) # noqa: S507 + + sock = None + if jumpbox_transport: + if not jumpbox_public_ip: + msg = "jumpbox_public_ip is required when using a jumpbox_transport." + raise ValueError(msg) + + # Create a tunnel ("direct-tcpip" channel) through the jumpbox + src_addr = (str(jumpbox_public_ip), 22) + dest_addr = (str(hostname), 22) + sock = jumpbox_transport.open_channel("direct-tcpip", dest_addr, src_addr) + + await asyncio.to_thread( + target_client.connect, + hostname=hostname, + username=username, + pkey=private_key, + sock=sock, + timeout=15, # Extra timeout to improve reliability + ) + return target_client + + +def get_ssh_username(provider: OpenLabsProvider, os: OpenLabsOS) -> str: + """Get the default SSH username for a given provider and OS.""" + if provider == OpenLabsProvider.AWS: + return AWS_SSH_USERNAME_MAP[os] + if provider == OpenLabsProvider.AZURE: + return AZURE_SSH_USERNAME_MAP[os] + + msg = f"Unsupported provider-OS combination: {provider}-{os}" + raise ValueError(msg) diff --git a/api/tests/integration/api/v1/test_ranges.py b/api/tests/integration/api/v1/test_ranges.py index f656b34f..8c5359bb 100644 --- a/api/tests/integration/api/v1/test_ranges.py +++ b/api/tests/integration/api/v1/test_ranges.py @@ -1,25 +1,26 @@ import asyncio import io +import logging import paramiko import pytest from httpx import AsyncClient -from src.app.enums.operating_systems import ( - AWS_SSH_USERNAME_MAP, - AZURE_SSH_USERNAME_MAP, - OpenLabsOS, -) -from src.app.enums.providers import OpenLabsProvider +from src.app.enums.operating_systems import OpenLabsOS +from src.app.schemas.host_schemas import DeployedHostSchema from src.app.schemas.range_schemas import DeployedRangeSchema from tests.api_test_utils import get_range, get_range_key, login_user from tests.deploy_test_utils import ( RangeType, + get_ssh_username, provider_test_id, range_test_id, + ssh_connect_to_host, ) from tests.integration.api.v1.config import PROVIDER_PARAMS, RANGE_TYPE_PARAMS +logger = logging.getLogger(__name__) + @pytest.mark.asyncio(loop_scope="session") @pytest.mark.parametrize( @@ -99,40 +100,19 @@ async def test_jumpbox_direct_connection( deployed_range = provider_deployed_ranges_for_provider[range_type] range_info, email, password = deployed_range - assert await login_user( - integration_client, email, password - ), "Failed to login to the deployed range account." - + assert await login_user(integration_client, email, password) private_key_str = await get_range_key(integration_client, range_info.id) - assert ( - private_key_str - ), f"Could not retrieve key for range with ID: {range_info.id}" + assert private_key_str, "Could not retrieve key for range." + + private_key = paramiko.RSAKey.from_private_key(io.StringIO(private_key_str)) + jumpbox_username = get_ssh_username(range_info.provider, OpenLabsOS.UBUNTU_22) ssh_client = None try: - private_key_file = io.StringIO(private_key_str) - private_key = paramiko.RSAKey.from_private_key(private_key_file) - - ssh_client = paramiko.SSHClient() - ssh_client.set_missing_host_key_policy( - paramiko.AutoAddPolicy() # noqa: S507 - ) - - # Connect directly to the jumpbox using its public IP - # Jumpbox typically uses Ubuntu, so get the Ubuntu username for the provider - if range_info.provider == OpenLabsProvider.AWS: - jumpbox_username = AWS_SSH_USERNAME_MAP[OpenLabsOS.UBUNTU_22] - elif range_info.provider == OpenLabsProvider.AZURE: - jumpbox_username = AZURE_SSH_USERNAME_MAP[OpenLabsOS.UBUNTU_22] - else: - pytest.fail(f"Unsupported provider: {range_info.provider}") - - await asyncio.to_thread( - ssh_client.connect, + ssh_client = await ssh_connect_to_host( hostname=str(range_info.jumpbox_public_ip), username=jumpbox_username, - pkey=private_key, - timeout=10, + private_key=private_key, ) # Validate command exexcution with 'id' command @@ -188,94 +168,42 @@ async def test_jumpbox_to_vm_connections( deployed_range = provider_deployed_ranges_for_provider[range_type] range_info, email, password = deployed_range - assert await login_user( - integration_client, email, password - ), "Failed to login to the deployed range account." - + assert await login_user(integration_client, email, password) private_key_str = await get_range_key(integration_client, range_info.id) - assert ( - private_key_str - ), f"Could not retrieve key for range with ID: {range_info.id}" - - # Extract all private IPs and their OS from range_info - host_info: list[dict[str, str]] = [] - for vpc in range_info.vpcs: - for subnet in vpc.subnets: - for host in subnet.hosts: - host_info.append( - { - "ip": str(host.ip_address), - "os": host.os.value, - "hostname": host.hostname, - } - ) + assert private_key_str, "Could not retrieve key for range." - ssh_client = None - try: - private_key_file = io.StringIO(private_key_str) - private_key = paramiko.RSAKey.from_private_key(private_key_file) + private_key = paramiko.RSAKey.from_private_key(io.StringIO(private_key_str)) + jumpbox_username = get_ssh_username(range_info.provider, OpenLabsOS.UBUNTU_22) - ssh_client = paramiko.SSHClient() - ssh_client.set_missing_host_key_policy( - paramiko.AutoAddPolicy() # noqa: S507 - ) + hosts: list[DeployedHostSchema] = [ + host + for vpc in range_info.vpcs + for subnet in vpc.subnets + for host in subnet.hosts + ] - # Connect to the jumpbox using its public IP - # Jumpbox typically uses Ubuntu, so get the Ubuntu username for the provider - if range_info.provider == OpenLabsProvider.AWS: - jumpbox_username = AWS_SSH_USERNAME_MAP[OpenLabsOS.UBUNTU_22] - elif range_info.provider == OpenLabsProvider.AZURE: - jumpbox_username = AZURE_SSH_USERNAME_MAP[OpenLabsOS.UBUNTU_22] - else: - pytest.fail(f"Unsupported provider: {range_info.provider}") - - await asyncio.to_thread( - ssh_client.connect, + ssh_client = None + try: + ssh_client = await ssh_connect_to_host( hostname=str(range_info.jumpbox_public_ip), username=jumpbox_username, - pkey=private_key, - timeout=10, + private_key=private_key, ) # Get jumpbox transport for tunneling jumpbox_transport = ssh_client.get_transport() assert jumpbox_transport is not None, "Failed to get SSH transport" - for host_data in host_info: - ip = host_data["ip"] - os_name = host_data["os"] - hostname = host_data["hostname"] - + for host in hosts: target_client = None try: - # Create a tunnel channel through the jumpbox - src_addr = (str(range_info.jumpbox_public_ip), 22) - dest_addr = (ip, 22) - jumpbox_channel = jumpbox_transport.open_channel( - "direct-tcpip", dest_addr, src_addr - ) - - target_client = paramiko.SSHClient() - target_client.set_missing_host_key_policy( - paramiko.AutoAddPolicy() # noqa: S507 - ) - - # Get the appropriate SSH username for this OS based on provider - os_enum = OpenLabsOS(os_name) - if range_info.provider == OpenLabsProvider.AWS: - username = AWS_SSH_USERNAME_MAP[os_enum] - elif range_info.provider == OpenLabsProvider.AZURE: - username = AZURE_SSH_USERNAME_MAP[os_enum] - else: - pytest.fail(f"Unsupported provider: {range_info.provider}") - - await asyncio.to_thread( - target_client.connect, - hostname=ip, + username = get_ssh_username(range_info.provider, host.os) + target_client = await ssh_connect_to_host( + hostname=str(host.ip_address), username=username, - pkey=private_key, - sock=jumpbox_channel, - timeout=10, + private_key=private_key, + jumpbox_transport=jumpbox_transport, + jumpbox_public_ip=str(range_info.jumpbox_public_ip), ) # Validate command execution with 'id' command @@ -290,13 +218,16 @@ async def test_jumpbox_to_vm_connections( ), f"Expected username '{username}' not found in output: {command_output}" assert ( not error_output - ), f"Error executing 'id' command on {hostname} ({ip}): {error_output}" - print( - f"Successfully verified user identity on {hostname} ({ip}) with username '{username}'" + ), f"Error executing 'id' command on {host.hostname} ({host.ip_address}): {error_output}" + logger.info( + "Successfully verified user identity on %s (%s) with username '%s'", + host.hostname, + host.ip_address, + username, ) except Exception as e: pytest.fail( - f"Exception connecting to {hostname} ({ip}) with username '{username}': {e}" + f"Exception connecting to {host.hostname} ({host.ip_address}) with username '{username}': {e}" ) finally: if target_client: diff --git a/api/tests/unit/utils/test_cdktf_utils.py b/api/tests/unit/utils/test_cdktf_utils.py new file mode 100644 index 00000000..278b01b3 --- /dev/null +++ b/api/tests/unit/utils/test_cdktf_utils.py @@ -0,0 +1,86 @@ +import pytest + +from src.app.utils.cdktf_utils import gen_resource_logical_ids + + +@pytest.mark.parametrize( + "input_names, expected_ids", + [ + # Basic cases + pytest.param([], {}, id="empty_list"), + pytest.param(["MyServer"], {"MyServer": "myserver"}, id="single_name"), + pytest.param( + ["Server1", "Server2"], + {"Server1": "server1", "Server2": "server2"}, + id="two_unique_names", + ), + pytest.param( + ["Web Server", "Database", "Auth Service"], + { + "Auth Service": "auth-service", + "Database": "database", + "Web Server": "web-server", + }, + id="multiple_unique_names_with_spaces", + ), + # Collision handling + pytest.param( + ["Web Server", "web-server"], + {"Web Server": "web-server", "web-server": "web-server-1"}, + id="two_collide_after_norm", + ), + pytest.param( + ["Server", "server", "sErVeR"], + {"Server": "server", "sErVeR": "server-1", "server": "server-2"}, + id="three_collide_after_norm", + ), + pytest.param( + ["App-Service", "App Service", "app_service"], + { + "App Service": "app-service", + "App-Service": "app-service-1", + "app_service": "app-service-2", + }, + id="multiple_collision_types", + ), + # Sorted order affects suffix assignment + pytest.param( + ["A", "a", "A-", "-a"], + {"A": "a-1", "-a": "a", "A-": "a-2", "a": "a-3"}, + id="complex_collision_leading_trailing_symbols", + ), + # Names that normalize to the same base but are distinct in input + pytest.param( + ["user-profile", "User Profile"], + {"user-profile": "user-profile-1", "User Profile": "user-profile"}, + id="distinct_input_same_norm", + ), + pytest.param( + ["my_app", "my-app"], + {"my_app": "my-app-1", "my-app": "my-app"}, + id="underscore_vs_hyphen", + ), + ], +) +def test_gen_resource_logical_ids_success( + input_names: list[str], expected_ids: dict[str, str] +) -> None: + """Test various valid inputs for gen_resource_logical_ids to validate ID generation and collision handling.""" + result = gen_resource_logical_ids(input_names) + assert result == expected_ids + + +@pytest.mark.parametrize( + "input_names", + [ + pytest.param(["Duplicate", "Duplicate"], id="exact_duplicate_names"), + pytest.param(["Test", "test", "Test"], id="exact_duplicate_mixed_case"), + pytest.param(["A", "B", "A", "C", "B"], id="multiple_exact_duplicates"), + ], +) +def test_gen_resource_logical_ids_duplicate_error( + input_names: list[str], +) -> None: + """Tests that gen_resource_logical_ids raises a ValueError when the input list contains exact duplicate names.""" + with pytest.raises(ValueError, match="duplicate"): + gen_resource_logical_ids(input_names) diff --git a/api/tests/unit/utils/test_name_utils.py b/api/tests/unit/utils/test_name_utils.py new file mode 100644 index 00000000..2d801afb --- /dev/null +++ b/api/tests/unit/utils/test_name_utils.py @@ -0,0 +1,79 @@ +import pytest + +from src.app.utils.name_utils import normalize_name + + +@pytest.mark.parametrize( + "input_name, expected_output", + [ + # Basic cases + ("My Awesome Lab", "my-awesome-lab"), + ("another_test_name", "another-test-name"), + ("simple-name", "simple-name"), # Already normalized + ("TestName", "testname"), # Mixed case + ("test123name", "test123name"), # Numbers + ("test-123-name", "test-123-name"), # Numbers with hyphens + # Special characters + ( + "Name!@#$%^&*()_+={}|[]\\:;\"'<>,.?/~`", + "name", + ), + ("Name with.dots", "name-with-dots"), + ("Name with/slashes", "name-with-slashes"), + ("Name with spaces and -hyphens", "name-with-spaces-and-hyphens"), + # Multiple consecutive problematic characters + ("Name--with---multiple____hyphens", "name-with-multiple-hyphens"), + ("Name with many spaces", "name-with-many-spaces"), + ("Name_-_With_Mixed_Separators", "name-with-mixed-separators"), + ("a---b", "a-b"), + ("a___b", "a-b"), + ("a...b", "a-b"), + ("a@@@b", "a-b"), + # Leading/trailing problematic characters + ("-Name-Starts-With-Hyphen", "name-starts-with-hyphen"), + ("Name-Ends-With-Hyphen-", "name-ends-with-hyphen"), + ( + " Name With Leading And Trailing Spaces ", + "name-with-leading-and-trailing-spaces", + ), + ("!@#NameWithSymbols!@#", "namewithsymbols"), + ("---Name---", "name"), + ("___Name___", "name"), + ("...Name...", "name"), + # Names that are already clean + ("already-kebab-case", "already-kebab-case"), + ("lowercasealphanumeric", "lowercasealphanumeric"), + ], +) +def test_normalize_name_valid_cases(input_name: str, expected_output: str) -> None: + """Test various valid input names to ensure they are normalized correctly to safe kebab-case versions.""" + assert normalize_name(input_name) == expected_output + + +@pytest.mark.parametrize( + "input_name", + [ + "", # Empty string + " ", # Only spaces + "---", # Only hyphens + "!!!", # Only special characters + " @#$ ", # Mixed problematic characters + " _ ", # Only underscores + " . ", # Only dots + ], + ids=[ + "empty_string", + "only_spaces", + "only_hyphens", + "only_special_chars", + "mixed_problematic_chars", + "only_underscores", + "only_dots", + ], +) +def test_normalize_name_empty_raises_value_error( + input_name: str, +) -> None: + """Test input names that should result in an empty string after normalization that force a ValueError to be raised.""" + with pytest.raises(ValueError, match="empty"): + normalize_name(input_name)