2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -1,7 +1,7 @@
 # See https://pre-commit.com for more information
 # See https://pre-commit.com/hooks.html for more hooks
 default_language_version:
-  python: python3.10
+  python: python3.12.5
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v5.0.0
73 changes: 22 additions & 51 deletions app/api/v1/endpoints/machine_endpoints.py
@@ -7,7 +7,7 @@
 import requests
 from typing import Annotated, List

-from fastapi import APIRouter, Depends, status
+from fastapi import APIRouter, Depends, status, HTTPException
 from fastapi.security import OAuth2PasswordBearer

 from core.config import settings
@@ -29,6 +29,7 @@
     MachineNotFoundError,
     MachineOperationFailedError,
 )
+from utils.ollama import ollama_pull_model, check_model_on_machine

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
@@ -247,34 +248,19 @@ async def pull_gpu_model(
         model_name = selection_request.model_name

         machine_public_ip = ec2_service.get_instance_public_ip(machine_id)
-        url = f"http://{machine_public_ip}:11434/api/pull"

-        model_request = {
-            "model": model_name,
-            "stream": False,
-        }
+        pull_response = ollama_pull_model(machine_public_ip, model_name)
+        if pull_response["status"] == "failed":
+            raise HTTPException(
+                status_code=500,
+                detail=f"Failed to pull model {model_name} to instance.",
+            )

-        response = requests.post(
-            url=url,
-            json=model_request,
-        )
+        check_response = check_model_on_machine(machine_public_ip, model_name)
+        if check_response["status"] == "failed":
+            raise HTTPException(status_code=404, detail=check_response["error"])

-        if response.status_code == 200:
-            response_data = response.json()
-            response_data["message"] = f"Model {model_name} pulled to instance"
-            return response_data
-        elif response.status_code == 500:
-            error_message = response.json().get("error", "Unknown error")
-            if error_message == "pull model manifest: file does not exist":
-                logger.error(f"Error: {error_message}")
-                return {"status": "failed", "error": error_message}
-            else:
-                response.raise_for_status()
-        else:
-            response.raise_for_status()
-    except httpx.HTTPStatusError as e:
-        logger.error(f"HTTP error occurred: {e}")
-        return {"error": str(e)}
+        return {"message": check_response["message"]}
     except Exception as e:
         logger.error(f"An error occurred: {e}")
         return {"error": str(e)}
@@ -437,34 +423,19 @@ async def pull_cpu_model(
         model_name = selection_request.model_name

         machine_public_ip = ec2_service.get_instance_public_ip(machine_id)
-        url = f"http://{machine_public_ip}:11434/api/pull"

-        model_request = {
-            "model": model_name,
-            "stream": False,
-        }
+        pull_response = ollama_pull_model(machine_public_ip, model_name)
+        if pull_response["status"] == "failed":
+            raise HTTPException(
+                status_code=500,
+                detail=f"Failed to pull model {model_name} to instance.",
+            )

-        response = requests.post(
-            url=url,
-            json=model_request,
-        )
+        check_response = check_model_on_machine(machine_public_ip, model_name)
+        if check_response["status"] == "failed":
+            raise HTTPException(status_code=404, detail=check_response["error"])

-        if response.status_code == 200:
-            response_data = response.json()
-            response_data["message"] = f"Model {model_name} pulled to instance"
-            return response_data
-        elif response.status_code == 500:
-            error_message = response.json().get("error", "Unknown error")
-            if error_message == "pull model manifest: file does not exist":
-                logger.error(f"Error: {error_message}")
-                return {"status": "failed", "error": error_message}
-            else:
-                response.raise_for_status()
-        else:
-            response.raise_for_status()
-    except httpx.HTTPStatusError as e:
-        logger.error(f"HTTP error occurred: {e}")
-        return {"error": str(e)}
+        return {"message": check_response["message"]}
     except Exception as e:
         logger.error(f"An error occurred: {e}")
         return {"error": str(e)}
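For reference, the helpers introduced in this PR wrap Ollama's HTTP API on port 11434. A minimal sketch of the underlying exchange, assuming a reachable instance (the host and model name below are placeholders, not values from this PR):

import requests

host = "203.0.113.10"  # placeholder EC2 public IP
model = "llama3.2"     # placeholder model name

# Pull a model; with "stream": False, Ollama replies with a single JSON
# object (e.g. {"status": "success"}) instead of streamed progress lines.
pull = requests.post(
    f"http://{host}:11434/api/pull",
    json={"model": model, "stream": False},
)
pull.raise_for_status()

# List the models present on the machine; the response has the shape
# {"models": [{"name": "llama3.2:latest", ...}, ...]}.
tags = requests.get(f"http://{host}:11434/api/tags")
print([m["name"] for m in tags.json().get("models", [])])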
52 changes: 52 additions & 0 deletions app/utils/ollama.py
@@ -0,0 +1,52 @@
"""
This module defines helper functions for interacting with the OLLAMA API.
"""

import requests
from requests.exceptions import RequestException


def ollama_pull_model(machine_public_ip, model_name):
url = f"http://{machine_public_ip}:11434/api/pull"
model_request = {
"model": model_name,
"stream": False,
}
try:
response = requests.post(
url=url,
json=model_request,
)
response.raise_for_status()
return {"status": response.status_code, "response": response.json()}
except RequestException as e:
error_message = f"An error occurred during pulling Ollama Model request: {e}"
return {
"status": "failed",
"message": error_message,
}


def check_model_on_machine(machine_public_ip, model_name):
check_url = f"http://{machine_public_ip}:11434/api/tags"
try:
check_response = requests.get(url=check_url)
check_response.raise_for_status()

if check_response.status_code == 200:
check_response_data = check_response.json()
for model in check_response_data.get("models", []):
if model_name in model["name"]:
return {
"status": "success",
"message": f"Model {model_name} pulled to instance",
}
return {
"status": "failed",
"error": f"Model {model_name} not found on instance",
}
except RequestException as e:
return {
"status": "failed",
"message": f"An error occurred during checking Ollama Model on machine: {e}",
}
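A quick usage sketch of these helpers outside the FastAPI endpoints (the IP address and model name are placeholders):

from utils.ollama import ollama_pull_model, check_model_on_machine

machine_ip = "203.0.113.10"  # placeholder public IP

pull = ollama_pull_model(machine_ip, "llama3.2")
if pull["status"] == "failed":
    print(pull["message"])
else:
    # On success this returns, e.g.,
    # {"status": "success", "message": "Model llama3.2 pulled to instance"}
    print(check_model_on_machine(machine_ip, "llama3.2"))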