diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index d689fff..020af2e 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -46,6 +46,8 @@ jobs: MODEL_S3_BUCKET: ${{ secrets.MODEL_S3_BUCKET }} HLS_TRANSFORM_DIRECTORY: ${{ secrets.HLS_TRANSFORM_DIRECTORY }} FPGA_DEV_AMI: ${{ secrets.FPGA_DEV_AMI }} + GPU_DEV_AMI: ${{ secrets.GPU_DEV_AMI }} + CPU_DEV_AMI: ${{ secrets.CPU_DEV_AMI }} EC2_IAM_ROLE: ${{ secrets.EC2_IAM_ROLE }} EC2_KEY_PAIR: ${{ secrets.EC2_KEY_PAIR }} EC2_SECURITY_GROUP: ${{ secrets.EC2_SECURITY_GROUP }} diff --git a/app/api/v1/endpoints/machine_endpoints.py b/app/api/v1/endpoints/machine_endpoints.py index 369fe87..eb696b4 100644 --- a/app/api/v1/endpoints/machine_endpoints.py +++ b/app/api/v1/endpoints/machine_endpoints.py @@ -3,6 +3,8 @@ """ import logging +import httpx +import requests from typing import Annotated, List from fastapi import APIRouter, Depends, status @@ -15,9 +17,13 @@ MachineCreate, ModelInferenceRequest, ModelInferenceResponse, + ModelSelectionRequest, ) from models.user import UserResponse -from scripts.ec2_setup import generate_hlstransform_setup_script +from scripts.ec2_setup import ( + generate_hlstransform_setup_script, + generate_ollama_setup_script, +) from services.ec2_service import EC2Service from services.user_service import get_current_active_user from utils.exceptions import ( @@ -29,15 +35,19 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) -router = APIRouter(tags=["machine"]) +router = APIRouter() oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth") +gpu_tag = "machine / gpu" +fpga_tag = "machine / fpga" +cpu_tag = "machine / cpu" + def get_ec2_service(): return EC2Service() -@router.get("/machines/unassigned", response_model=List[Machine]) +@router.get("/machines/unassigned", response_model=List[Machine], tags=["machine"]) async def list_unassigned_machines( current_user: Annotated[UserResponse, Depends(get_current_active_user)], ec2_service: EC2Service = Depends(get_ec2_service), @@ -53,7 +63,7 @@ async def list_unassigned_machines( ) -@router.get("/machines", response_model=List[Machine]) +@router.get("/machines", response_model=List[Machine], tags=["machine"]) async def list_user_machines( current_user: Annotated[UserResponse, Depends(get_current_active_user)], ec2_service: EC2Service = Depends(get_ec2_service), @@ -69,7 +79,7 @@ async def list_user_machines( ) -@router.get("/machines/{machine_id}", response_model=Machine) +@router.get("/machine/{machine_id}", response_model=Machine, tags=["machine"]) async def get_machine_details( current_user: Annotated[UserResponse, Depends(get_current_active_user)], machine_id: str, @@ -89,7 +99,9 @@ async def get_machine_details( ) -@router.post("/machines/start/{machine_id}", response_model=MessageResponse) +@router.post( + "/machine/start/{machine_id}", response_model=MessageResponse, tags=["machine"] +) async def start_machine( current_user: Annotated[UserResponse, Depends(get_current_active_user)], machine_id: str, @@ -109,7 +121,9 @@ async def start_machine( ) -@router.post("/machines/stop/{machine_id}", response_model=MessageResponse) +@router.post( + "/machine/stop/{machine_id}", response_model=MessageResponse, tags=["machine"] +) async def stop_machine( current_user: Annotated[UserResponse, Depends(get_current_active_user)], machine_id: str, @@ -129,8 +143,13 @@ async def stop_machine( ) -@router.post("/machine", response_model=Machine, status_code=status.HTTP_201_CREATED) -async def create_machine( +@router.post( + "/machine/fpga", + response_model=Machine, + status_code=status.HTTP_201_CREATED, + tags=[fpga_tag], +) +async def create_fpga_machine( current_user: Annotated[UserResponse, Depends(get_current_active_user)], machine_create: MachineCreate, ec2_service: EC2Service = Depends(get_ec2_service), @@ -161,7 +180,42 @@ async def create_machine( ) -@router.delete("/machines/{machine_id}", response_model=MessageResponse) +@router.post( + "/machine/gpu", + response_model=Machine, + status_code=status.HTTP_201_CREATED, + tags=[gpu_tag], +) +async def create_gpu_machine( + current_user: Annotated[UserResponse, Depends(get_current_active_user)], + machine_create: MachineCreate, + ec2_service: EC2Service = Depends(get_ec2_service), +) -> Machine: + try: + user_data = generate_ollama_setup_script(user_name=current_user.user_name) + + logger.debug(f"Generated user data: {user_data}") + + response = ec2_service.create_gpu_machine( + user_id=current_user.user_id, + instance_name=machine_create.machine_name, + instance_type=machine_create.machine_type, + user_data=user_data, + ) + + return response + except Exception as e: + logger.error(f"An error occurred: {e}") + raise EC2Error( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"An unexpected error occurred while creating the machine: {e}", + error_code="INTERNAL_SERVER_ERROR", + ) + + +@router.delete( + "/machine/{machine_id}", response_model=MessageResponse, tags=["machine"] +) async def terminate_machine( current_user: Annotated[UserResponse, Depends(get_current_active_user)], machine_id: str, @@ -181,7 +235,11 @@ async def terminate_machine( ) -@router.post("/machines/{machine_id}/inference", response_model=ModelInferenceResponse) +@router.post( + "/machines/{machine_id}/inference", + response_model=ModelInferenceResponse, + tags=[fpga_tag], +) async def run_model_inference( machine_id: str, request: ModelInferenceRequest, @@ -207,3 +265,278 @@ async def run_model_inference( detail="A fatal server error occurred while running model inference", error_code="INTERNAL_SERVER_ERROR", ) + + +@router.post("/machine/gpu/pull_model", tags=[gpu_tag]) +async def pull_gpu_model( + current_user: Annotated[UserResponse, Depends(get_current_active_user)], + selection_request: ModelSelectionRequest, + ec2_service: Annotated[EC2Service, Depends(get_ec2_service)], +): + try: + machine_id = selection_request.machine_id + model_name = selection_request.model_name + + machine_public_ip = ec2_service.get_instance_public_ip(machine_id) + url = f"http://{machine_public_ip}:11434/api/pull" + + model_request = { + "model": model_name, + "stream": False, + } + + response = requests.post( + url=url, + json=model_request, + ) + + if response.status_code == 200: + response_data = response.json() + response_data["message"] = f"Model {model_name} pulled to instance" + return response_data + elif response.status_code == 500: + error_message = response.json().get("error", "Unknown error") + if error_message == "pull model manifest: file does not exist": + logger.error(f"Error: {error_message}") + return {"status": "failed", "error": error_message} + else: + response.raise_for_status() + else: + response.raise_for_status() + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error occurred: {e}") + return {"error": str(e)} + except Exception as e: + logger.error(f"An error occurred: {e}") + return {"error": str(e)} + + +@router.delete("/machine/gpu/model", tags=[gpu_tag]) +async def delete_gpu_model( + current_user: Annotated[UserResponse, Depends(get_current_active_user)], + selection_request: ModelSelectionRequest, + ec2_service: Annotated[EC2Service, Depends(get_ec2_service)], +): + try: + machine_id = selection_request.machine_id + model_name = selection_request.model_name + + machine_public_ip = ec2_service.get_instance_public_ip(machine_id) + url = f"http://{machine_public_ip}:11434/api/delete" + + model_delete_request = { + "model": model_name, + } + + response = requests.delete( + url=url, + json=model_delete_request, + ) + + if response.status_code == 200: + return { + "status": "success", + "message": f"Model {model_name} deleted from instance", + } + elif response.status_code == 404: + error_message = "Model not found on instance" + logger.error(f"Error: {error_message}") + return {"status": "failed", "error": error_message} + else: + response.raise_for_status() + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error occurred: {e}") + return {"error": str(e)} + except Exception as e: + logger.error(f"An error occurred: {e}") + return {"error": str(e)} + + +@router.get("/machine/gpu/{machine_id}/inference_url", tags=[gpu_tag]) +async def get_gpu_inference_url( + machine_id: str, + current_user: Annotated[UserResponse, Depends(get_current_active_user)], + ec2_service: EC2Service = Depends(get_ec2_service), +): + try: + isOwner = ec2_service.is_user_owner_of_instance( + user_id=current_user.user_id, instance_id=machine_id + ) + if not isOwner: + raise EC2Error( + status_code=status.HTTP_403_FORBIDDEN, + detail="User not the owner of this machine", + error_code="FORBIDDEN", + ) + public_ip = ec2_service.get_instance_public_ip(machine_id) + ollama_url = f"http://{public_ip}:11434/api/generate" + + return {"inference_url": ollama_url} + except EC2Error as e: + logger.error(f"An error occurred: {e}") + raise EC2Error( + status_code=e.status_code, + detail=f"An EC2 Error occurred: {e.detail}", + error_code=e.error_code, + ) + except Exception as e: + logger.error(f"An internal server error occurred: {e}") + raise EC2Error( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"A server error occurred while getting the inference URL: {e}", + error_code="INTERNAL_SERVER_ERROR", + ) + + +@router.post( + "/machine/cpu", + response_model=Machine, + status_code=status.HTTP_201_CREATED, + tags=[cpu_tag], +) +async def create_cpu_machine( + current_user: Annotated[UserResponse, Depends(get_current_active_user)], + machine_create: MachineCreate, + ec2_service: EC2Service = Depends(get_ec2_service), +) -> Machine: + try: + user_data = generate_ollama_setup_script(user_name=current_user.user_name) + + logger.debug(f"Generated user data: {user_data}") + + response = ec2_service.create_gpu_machine( + user_id=current_user.user_id, + instance_name=machine_create.machine_name, + instance_type=machine_create.machine_type, + user_data=user_data, + ) + + return response + except Exception as e: + logger.error(f"An error occurred: {e}") + raise EC2Error( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"An unexpected error occurred while creating the machine: {e}", + error_code="INTERNAL_SERVER_ERROR", + ) + + +@router.post("/machine/cpu/pull_model", tags=[cpu_tag]) +async def pull_cpu_model( + current_user: Annotated[UserResponse, Depends(get_current_active_user)], + selection_request: ModelSelectionRequest, + ec2_service: Annotated[EC2Service, Depends(get_ec2_service)], +): + try: + machine_id = selection_request.machine_id + model_name = selection_request.model_name + + machine_public_ip = ec2_service.get_instance_public_ip(machine_id) + url = f"http://{machine_public_ip}:11434/api/pull" + + model_request = { + "model": model_name, + "stream": False, + } + + response = requests.post( + url=url, + json=model_request, + ) + + if response.status_code == 200: + response_data = response.json() + response_data["message"] = f"Model {model_name} pulled to instance" + return response_data + elif response.status_code == 500: + error_message = response.json().get("error", "Unknown error") + if error_message == "pull model manifest: file does not exist": + logger.error(f"Error: {error_message}") + return {"status": "failed", "error": error_message} + else: + response.raise_for_status() + else: + response.raise_for_status() + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error occurred: {e}") + return {"error": str(e)} + except Exception as e: + logger.error(f"An error occurred: {e}") + return {"error": str(e)} + + +@router.delete("/machine/cpu/model", tags=[cpu_tag]) +async def delete_cpu_model( + current_user: Annotated[UserResponse, Depends(get_current_active_user)], + selection_request: ModelSelectionRequest, + ec2_service: Annotated[EC2Service, Depends(get_ec2_service)], +): + try: + machine_id = selection_request.machine_id + model_name = selection_request.model_name + + machine_public_ip = ec2_service.get_instance_public_ip(machine_id) + url = f"http://{machine_public_ip}:11434/api/delete" + + model_delete_request = { + "model": model_name, + } + + response = requests.delete( + url=url, + json=model_delete_request, + ) + + if response.status_code == 200: + return { + "status": "success", + "message": f"Model {model_name} deleted from instance", + } + elif response.status_code == 404: + error_message = "Model not found on instance" + logger.error(f"Error: {error_message}") + return {"status": "failed", "error": error_message} + else: + response.raise_for_status() + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error occurred: {e}") + return {"error": str(e)} + except Exception as e: + logger.error(f"An error occurred: {e}") + return {"error": str(e)} + + +@router.get("/machine/cpu/{machine_id}/inference_url", tags=[cpu_tag]) +async def get_cpu_inference_url( + machine_id: str, + current_user: Annotated[UserResponse, Depends(get_current_active_user)], + ec2_service: EC2Service = Depends(get_ec2_service), +): + try: + isOwner = ec2_service.is_user_owner_of_instance( + user_id=current_user.user_id, instance_id=machine_id + ) + if not isOwner: + raise EC2Error( + status_code=status.HTTP_403_FORBIDDEN, + detail="User not the owner of this machine", + error_code="FORBIDDEN", + ) + public_ip = ec2_service.get_instance_public_ip(machine_id) + ollama_url = f"http://{public_ip}:11434/api/generate" + + return {"inference_url": ollama_url} + except EC2Error as e: + logger.error(f"An error occurred: {e}") + raise EC2Error( + status_code=e.status_code, + detail=f"An EC2 Error occurred: {e.detail}", + error_code=e.error_code, + ) + except Exception as e: + logger.error(f"An internal server error occurred: {e}") + raise EC2Error( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"A server error occurred while getting the inference URL: {e}", + error_code="INTERNAL_SERVER_ERROR", + ) diff --git a/app/core/config.py b/app/core/config.py index 7efc9b1..4407546 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -18,6 +18,9 @@ class Settings(BaseSettings): # Machine Management FPGA_DEV_AMI: str + GPU_DEV_AMI: str + CPU_DEV_AMI: str + EC2_IAM_ROLE: str EC2_KEY_PAIR: str EC2_SECURITY_GROUP: str diff --git a/app/models/machine.py b/app/models/machine.py index 108e5db..1694a83 100644 --- a/app/models/machine.py +++ b/app/models/machine.py @@ -33,3 +33,12 @@ class ModelInferenceRequest(BaseModel): class ModelInferenceResponse(BaseModel): output: str + + +class ModelSelectionRequest(BaseModel): + machine_id: str = Field(..., description="ID of the machine to select.") + model_name: str = Field(..., description="Name of the model to select.") + + class Config: + populate_by_name = True + protected_namespaces = () diff --git a/app/scripts/ec2_setup.py b/app/scripts/ec2_setup.py index d7e536b..f368f5a 100644 --- a/app/scripts/ec2_setup.py +++ b/app/scripts/ec2_setup.py @@ -70,3 +70,79 @@ def generate_hlstransform_setup_script( echo "Setup complete. You can now run your application with './llama2' as the '{user_name}' user." """ + + +def generate_ollama_setup_script(user_name: str) -> str: + return f"""#!/bin/bash +# Log output to file for debugging purposes +exec > /var/log/user-data.log 2>&1 +set -x + +# Update package list and install required packages +yum update -y # Use 'apt-get' if using Ubuntu or Debian +yum install -y aws-cli git # Install AWS CLI and Git, required for the script + +# Install SSM Agent +echo "→ Installing SSM Agent..." +yum install -y https://s3.{settings.AWS_DEFAULT_REGION}.amazonaws.com/amazon-ssm-{settings.AWS_DEFAULT_REGION}/latest/linux_amd64/amazon-ssm-agent.rpm + +# Start SSM Agent +systemctl enable amazon-ssm-agent +systemctl start amazon-ssm-agent + +# Create '{user_name}' user with a home directory +echo "→ Creating '{user_name}' user with a home directory..." +useradd -m -d /home/{user_name} {user_name} +chown -R {user_name}:{user_name} /home/{user_name} +chmod 755 /home/{user_name} + +# Install Ollama +echo "→ Installing Ollama..." +curl -fsSL https://ollama.com/install.sh | sh + +# Enable and start Ollama service +systemctl enable ollama +systemctl start ollama + +# 1. Check status of ollama port (11434) +echo "→ Checking status of ollama port..." +netstat -a | grep 11434 + +# 2. Update the ollama.service +echo "→ Updating ollama.service..." +if sudo systemctl status ollama.service > /dev/null 2>&1; then + # Create the override directory + echo "→→ Create override dir..." + sudo mkdir -p /etc/systemd/system/ollama.service.d + + # Write the override file + echo "→→ Write override file..." + echo -e "[Service]\nEnvironment=\"OLLAMA_HOST=0.0.0.0\"" | sudo tee /etc/systemd/system/ollama.service.d/override.conf > /dev/null + +else + echo "→ ollama.service not found! Please ensure it is installed and try again." + exit 1 +fi + +# 3. Reload the systemd daemon +echo "→ Reloading systemd daemon..." +sudo systemctl daemon-reload + +# 4. Restart the ollama service +echo "→ Restarting ollama service..." +sudo systemctl restart ollama + +# Confirm if the service restarted successfully +if sudo systemctl is-active --quiet ollama; then + echo "→ ollama service restarted successfully!" +else + echo "→ Failed to restart ollama service." + exit 1 +fi + +# 5. Check status of ollama port (11434) after restart +echo "→ Checking status of ollama port after restart..." +netstat -a | grep 11434 + +echo "→ Ollama installed. Ready for use." +""" diff --git a/app/services/ec2_service.py b/app/services/ec2_service.py index 0d4c6dd..b982212 100644 --- a/app/services/ec2_service.py +++ b/app/services/ec2_service.py @@ -222,6 +222,110 @@ def create_machine( error_code="AWS_CLIENT_ERROR", ) + def create_gpu_machine( + self, + instance_type: str, + instance_name: str, + user_id: str, + user_data: str, + ) -> Machine: + try: + ami_id = settings.GPU_DEV_AMI + role_arn = settings.EC2_IAM_ROLE + security_group_id = settings.EC2_SECURITY_GROUP + key_pair_name = settings.EC2_KEY_PAIR + + response = self.ec2.run_instances( + ImageId=ami_id, + InstanceType=instance_type, + KeyName=key_pair_name, + MinCount=1, + MaxCount=1, + IamInstanceProfile={"Arn": role_arn}, + SecurityGroupIds=[security_group_id], + UserData=user_data, + TagSpecifications=[ + { + "ResourceType": "instance", + "Tags": [ + {"Key": "Name", "Value": instance_name}, + {"Key": "user_id", "Value": user_id}, + {"Key": "assigned", "Value": "true"}, + ], + }, + ], + ) + + instance = response["Instances"][0] + instance_id = instance["InstanceId"] + + # Wait for the instance to be in a running state + waiter = self.ec2.get_waiter("instance_running") + waiter.wait(InstanceIds=[instance_id]) + + instance_details = self.get_machine_details(instance_id) + + return instance_details + + except ClientError as e: + raise EC2Error( + status_code=400, + detail=f"Failed to create instance: {str(e)}", + error_code="AWS_CLIENT_ERROR", + ) + + def create_cpu_machine( + self, + instance_type: str, + instance_name: str, + user_id: str, + user_data: str, + ) -> Machine: + try: + ami_id = settings.CPU_DEV_AMI + role_arn = settings.EC2_IAM_ROLE + security_group_id = settings.EC2_SECURITY_GROUP + key_pair_name = settings.EC2_KEY_PAIR + + response = self.ec2.run_instances( + ImageId=ami_id, + InstanceType=instance_type, + KeyName=key_pair_name, + MinCount=1, + MaxCount=1, + IamInstanceProfile={"Arn": role_arn}, + SecurityGroupIds=[security_group_id], + UserData=user_data, + TagSpecifications=[ + { + "ResourceType": "instance", + "Tags": [ + {"Key": "Name", "Value": instance_name}, + {"Key": "user_id", "Value": user_id}, + {"Key": "assigned", "Value": "true"}, + ], + }, + ], + ) + + instance = response["Instances"][0] + instance_id = instance["InstanceId"] + + # Wait for the instance to be in a running state + waiter = self.ec2.get_waiter("instance_running") + waiter.wait(InstanceIds=[instance_id]) + + instance_details = self.get_machine_details(instance_id) + + return instance_details + + except ClientError as e: + raise EC2Error( + status_code=400, + detail=f"Failed to create instance: {str(e)}", + error_code="AWS_CLIENT_ERROR", + ) + def terminate_machine(self, machine_id: str) -> bool: try: response = self.ec2.terminate_instances(InstanceIds=[machine_id]) @@ -362,3 +466,33 @@ def get_command_output(self, instance_id: str, command_id: str) -> str: raise EC2Error( status_code=400, detail=detail_msg, error_code=custom_error_code ) + + def get_instance_public_ip(self, instance_id: str) -> str: + try: + response = self.ec2.describe_instances(InstanceIds=[instance_id]) + return response["Reservations"][0]["Instances"][0]["PublicIpAddress"] + except ClientError as e: + raise EC2Error( + status_code=400, + detail=str(e), + error_code="AWS_CLIENT_ERROR", + ) + + def is_user_owner_of_instance(self, user_id: str, instance_id: str) -> bool: + try: + response = self.ec2.describe_instances(InstanceIds=[instance_id]) + instance = response["Reservations"][0]["Instances"][0] + tags = instance.get("Tags", []) + owner_tag = next((tag for tag in tags if tag["Key"] == "user_id"), None) + isOwner = owner_tag["Value"].lower() == user_id.lower() + return isOwner + except ClientError as e: + if e.response["Error"]["Code"] in [ + "InvalidInstanceID.Malformed", + "InvalidInstanceID.NotFound", + ]: + raise ValueError(f"Error: {e.response['Error']['Message']}") + else: + raise RuntimeError(f"ClientError: {str(e)}") + except Exception as e: + raise RuntimeError(f"An internal server error occurred: {str(e)}") diff --git a/requirements.txt b/requirements.txt index 8041d49..0c24bb2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -49,6 +49,7 @@ python-jose==3.3.0 python-multipart==0.0.9 pytz==2024.2 PyYAML==6.0.2 +requests==2.32.3 rich==13.8.1 rsa==4.7.2 ruff==0.6.4 diff --git a/test/unit/test_ec2_service.py b/test/unit/test_ec2_service.py index af8e60c..da0634a 100644 --- a/test/unit/test_ec2_service.py +++ b/test/unit/test_ec2_service.py @@ -304,3 +304,85 @@ def test_get_instance_name(self): self.ec2_service.get_instance_name(instance_without_name_or_keyname), "Unnamed Instance", ) + + def test_get_instance_public_ip_success(self): + mock_response = { + "Reservations": [{"Instances": [{"PublicIpAddress": "123.123.123.123"}]}] + } + self.mock_ec2_client.describe_instances.return_value = mock_response + + public_ip = self.ec2_service.get_instance_public_ip("i-1234567890abcdef0") + self.assertEqual(public_ip, "123.123.123.123") + + def test_get_instance_public_ip_client_error(self): + self.mock_ec2_client.describe_instances.side_effect = ClientError( + { + "Error": { + "Code": "InvalidInstanceID.NotFound", + "Message": "The instance ID 'i-1234567890abcdef0' does not exist", + } + }, + "DescribeInstances", + ) + + with self.assertRaises(EC2Error) as context: + self.ec2_service.get_instance_public_ip("i-1234567890abcdef0") + + self.assertEqual(context.exception.status_code, 400) + self.assertIn( + "The instance ID 'i-1234567890abcdef0' does not exist", + context.exception.detail, + ) + self.assertEqual(context.exception.error_code, "AWS_CLIENT_ERROR") + + def test_create_gpu_machine(self): + self.mock_ec2_client.run_instances.return_value = { + "Instances": [{"InstanceId": "i-1234567890abcdef0"}] + } + mock_waiter = self.mock_ec2_client.get_waiter.return_value + mock_waiter.wait.return_value = None + self.ec2_service.get_machine_details = MagicMock( + return_value="mock_instance_details" + ) + + response = self.ec2_service.create_gpu_machine( + instance_type="t2.micro", + instance_name="test-instance", + user_id="test-user123", + user_data="echo Hello World", + ) + + self.assertEqual(response, "mock_instance_details") + self.mock_ec2_client.run_instances.assert_called_once() + self.mock_ec2_client.get_waiter.assert_called_once_with("instance_running") + mock_waiter.wait.assert_called_once_with(InstanceIds=["i-1234567890abcdef0"]) + self.ec2_service.get_machine_details.assert_called_once_with( + "i-1234567890abcdef0" + ) + + def test_create_gpu_machine_client_error(self): + self.mock_ec2_client.run_instances.side_effect = ClientError( + error_response={ + "Error": { + "Code": "InvalidParameterValue", + "Message": "The parameter is incorrect", + } + }, + operation_name="RunInstances", + ) + + with self.assertRaises(EC2Error) as context: + self.ec2_service.create_gpu_machine( + instance_type="t2.micro", + instance_name="test-instance", + user_id="test-user123", + user_data="echo Hello World", + ) + + self.assertEqual( + context.exception.detail, + "Failed to create instance: An error occurred (InvalidParameterValue) when calling the RunInstances operation: The parameter is incorrect", + ) + self.assertEqual(context.exception.status_code, 400) + self.assertEqual(context.exception.error_code, "AWS_CLIENT_ERROR") + self.mock_ec2_client.run_instances.assert_called_once()