Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions backend/lib/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Union
from typing import List, Optional, Union

# Add these models after your existing Pydantic models
class TranscriptionRequest(BaseModel):
Expand All @@ -20,7 +20,7 @@ class ImageUrl(BaseModel):
class MessageContent(BaseModel):
type: str # "text" or "image_url"
text: Optional[str] = None
image_url: Optional[ImageUrl] = None
image_url: Optional[ImageUrl] = None

# chat completion request and response models
class UsageDetails(BaseModel):
Expand All @@ -30,14 +30,14 @@ class UsageDetails(BaseModel):

class Usage(BaseModel):
prompt_tokens: int
completion_tokens: int
completion_tokens: Optional[int] = None
total_tokens: int
prompt_tokens_details: Optional[UsageDetails] = None
completion_tokens_details: Optional[UsageDetails] = None

class Message(BaseModel):
role: str
content: Union[str, List[MessageContent]] # Can be string or list for vision
content: Union[str,List[MessageContent]] # Can be string or list for vision
name: Optional[str] = None

class ChatCompletionRequest(BaseModel):
Expand Down Expand Up @@ -69,7 +69,7 @@ class ChatCompletionResponse(BaseModel):
# embeddings request and response models
class EmbeddingInput(BaseModel):
model: str
input: List[str]
input: Union[str,List[str]]
user: Optional[str] = None

class EmbeddingData(BaseModel):
Expand All @@ -88,4 +88,4 @@ class SpeechRequest(BaseModel):
input: str
voice: str = "alloy" # alloy, echo, fable, onyx, nova, shimmer
response_format: Optional[str] = "mp3" # mp3, opus, aac, flac, wav, pcm
speed: Optional[float] = Field(1.0, ge=0.25, le=4.0) # Speed between 0.25 and 4.0
speed: Optional[float] = Field(1.0, ge=0.25, le=4.0) # Speed between 0.25 and 4.0
12 changes: 8 additions & 4 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ async def chat_completions(request: ChatCompletionRequest, user_key = Depends(ve
if content_item.image_url.url.startswith(("http://", "https://")):
content_item.image_url.url = await fetch_image_as_base64(content_item.image_url.url)

request_data = request.dict(by_alias=True)
request_data = request.model_dump(by_alias=True)

# Convert to OpenAI format for vision messages
if has_images:
Expand Down Expand Up @@ -160,7 +160,11 @@ async def chat_completions(request: ChatCompletionRequest, user_key = Depends(ve
# Keep OpenAI-compatible parameters only
allowed_params = ["model", "messages", "stream", "max_tokens", "temperature", "top_p", "n", "stop", "presence_penalty", "frequency_penalty", "user"]
request_data = {k: v for k, v in request_data.items() if k in allowed_params and v is not None}


# Some models don't allow both temperature and top_p
if 'temperature' in request_data and 'top_p' in request_data:
# Remove top_p, keep temperature (or vice versa based on your preference)
request_data.pop('top_p')
# Don't truncate messages with images
if model_config['params'].get('max_input_tokens') and not has_images:
# Only truncate text-only messages
Expand Down Expand Up @@ -343,7 +347,7 @@ async def event_generator():
@app.post("/v1/embeddings")
async def create_embedding(request: EmbeddingInput, user_key = Depends(verify_token)):
model_config = get_model_config(request.model, user_key)
request_data = request.dict()
request_data = request.model_dump()

request_data["model"] = model_config['params']['model'] # Maps "devstral" to "devstral:24b"

Expand Down Expand Up @@ -599,4 +603,4 @@ async def list_models(user_key = Depends(verify_token)):

if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
uvicorn.run(app, host="0.0.0.0", port=8000)
4 changes: 3 additions & 1 deletion backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,6 @@ codecarbon
prometheus-client
python-multipart
tiktoken
requests
requests
pytest
httpx
122 changes: 122 additions & 0 deletions backend/test_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from fastapi.testclient import TestClient
from main import app
from lib.types import ChatCompletionResponse, EmbeddingResponse
import os


# Bearer token used for all authenticated requests; KeyError here (TOKEN
# unset) fails the whole suite loudly at import time, which is intentional.
# NOTE: do not print the token — it is a secret and would leak into CI logs.
token = os.environ["TOKEN"]
client = TestClient(app)

b64_image = "data:image/gif;base64,R0lGODlhPQBEAPeoAJosM//AwO/AwHVYZ/z595kzAP/s7P+goOXMv8+fhw/v739/f+8PD98fH/8mJl+fn/9ZWb8/PzWlwv///6wWGbImAPgTEMImIN9gUFCEm/gDALULDN8PAD6atYdCTX9gUNKlj8wZAKUsAOzZz+UMAOsJAP/Z2ccMDA8PD/95eX5NWvsJCOVNQPtfX/8zM8+QePLl38MGBr8JCP+zs9myn/8GBqwpAP/GxgwJCPny78lzYLgjAJ8vAP9fX/+MjMUcAN8zM/9wcM8ZGcATEL+QePdZWf/29uc/P9cmJu9MTDImIN+/r7+/vz8/P8VNQGNugV8AAF9fX8swMNgTAFlDOICAgPNSUnNWSMQ5MBAQEJE3QPIGAM9AQMqGcG9vb6MhJsEdGM8vLx8fH98AANIWAMuQeL8fABkTEPPQ0OM5OSYdGFl5jo+Pj/+pqcsTE78wMFNGQLYmID4dGPvd3UBAQJmTkP+8vH9QUK+vr8ZWSHpzcJMmILdwcLOGcHRQUHxwcK9PT9DQ0O/v70w5MLypoG8wKOuwsP/g4P/Q0IcwKEswKMl8aJ9fX2xjdOtGRs/Pz+Dg4GImIP8gIH0sKEAwKKmTiKZ8aB/f39Wsl+LFt8dgUE9PT5x5aHBwcP+AgP+WltdgYMyZfyywz78AAAAAAAD///8AAP9mZv///wAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACH5BAEAAKgALAAAAAA9AEQAAAj/AFEJHEiwoMGDCBMqXMiwocAbBww4nEhxoYkUpzJGrMixogkfGUNqlNixJEIDB0SqHGmyJSojM1bKZOmyop0gM3Oe2liTISKMOoPy7GnwY9CjIYcSRYm0aVKSLmE6nfq05QycVLPuhDrxBlCtYJUqNAq2bNWEBj6ZXRuyxZyDRtqwnXvkhACDV+euTeJm1Ki7A73qNWtFiF+/gA95Gly2CJLDhwEHMOUAAuOpLYDEgBxZ4GRTlC1fDnpkM+fOqD6DDj1aZpITp0dtGCDhr+fVuCu3zlg49ijaokTZTo27uG7Gjn2P+hI8+PDPERoUB318bWbfAJ5sUNFcuGRTYUqV/3ogfXp1rWlMc6awJjiAAd2fm4ogXjz56aypOoIde4OE5u/F9x199dlXnnGiHZWEYbGpsAEA3QXYnHwEFliKAgswgJ8LPeiUXGwedCAKABACCN+EA1pYIIYaFlcDhytd51sGAJbo3onOpajiihlO92KHGaUXGwWjUBChjSPiWJuOO/LYIm4v1tXfE6J4gCSJEZ7YgRYUNrkji9P55sF/ogxw5ZkSqIDaZBV6aSGYq/lGZplndkckZ98xoICbTcIJGQAZcNmdmUc210hs35nCyJ58fgmIKX5RQGOZowxaZwYA+JaoKQwswGijBV4C6SiTUmpphMspJx9unX4KaimjDv9aaXOEBteBqmuuxgEHoLX6Kqx+yXqqBANsgCtit4FWQAEkrNbpq7HSOmtwag5w57GrmlJBASEU18ADjUYb3ADTinIttsgSB1oJFfA63bduimuqKB1keqwUhoCSK374wbujvOSu4QG6UvxBRydcpKsav++Ca6G8A6Pr1x2kVMyHwsVxUALDq/krnrhPSOzXG1lUTIoffqGR7Goi2MAxbv6O2
kEG56I7CSlRsEFKFVyovDJoIRTg7sugNRDGqCJzJgcKE0ywc0ELm6KBCCJo8DIPFeCWNGcyqNFE06ToAfV0HBRgxsvLThHn1oddQMrXj5DyAQgjEHSAJMWZwS3HPxT/QMbabI/iBCliMLEJKX2EEkomBAUCxRi42VDADxyTYDVogV+wSChqmKxEKCDAYFDFj4OmwbY7bDGdBhtrnTQYOigeChUmc1K3QTnAUfEgGFgAWt88hKA6aCRIXhxnQ1yg3BCayK44EWdkUQcBByEQChFXfCB776aQsG0BIlQgQgE8qO26X1h8cEUep8ngRBnOy74E9QgRgEAC8SvOfQkh7FDBDmS43PmGoIiKUUEGkMEC/PJHgxw0xH74yx/3XnaYRJgMB8obxQW6kL9QYEJ0FIFgByfIL7/IQAlvQwEpnAC7DtLNJCKUoO/w45c44GwCXiAFB/OXAATQryUxdN4LfFiwgjCNYg+kYMIEFkCKDs6PKAIJouyGWMS1FSKJOMRB/BoIxYJIUXFUxNwoIkEKPAgCBZSQHQ1A2EWDfDEUVLyADj5AChSIQW6gu10bE/JG2VnCZGfo4R4d0sdQoBAHhPjhIB94v/wRoRKQWGRHgrhGSQJxCS+0pCZbEhAAOw=="
# Publicly hosted sample image; used by the (currently disabled) URL-image
# vision test below.
url_image = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"

# Text-only chat completion payload shared by the chat tests.
chatRequest = {
    "model": "gemma3",  # must be a model name the proxy exposes — TODO confirm against config
    "messages": [
        {
            "role": "user",  # or "system", "assistant"
            "content": "Hello, how are you?",
            # "name": "optional_name"  # Only needed for named users in some APIs
        }
    ],
    "max_tokens": 100,
    "temperature": 0.7,  # sampling temperature (0-2)
    "top_p": 1,
    "n": 1,
    "stream": False,
    "stop": None,  # or ["string"] to stop at certain phrases
    "presence_penalty": 0,
    "frequency_penalty": 0
    # Note: 'usage' is part of the response, not the request
}

# Embeddings payload (list-of-strings input form) shared by the embedding tests.
embeddingRequest = {
    "model": "text-embedding-3-small",
    "input": [
        "Message a embeder"
    ],
    "user": "string"
}


def test_read_main():
    """The proxy exposes no root route, so GET / must return 404."""
    resp = client.get("/")
    assert resp.status_code == 404
    assert resp.json() == {"detail":"Not Found"}
def test_list_models():
    """An authenticated GET /v1/models returns a list under "data"."""
    resp = client.get("/v1/models", headers={"Authorization": "Bearer " + token})
    assert resp.status_code == 200
    payload = resp.json()
    assert isinstance(payload["data"], list)
def test_completion_chat():
    """POST a text-only chat completion and validate the response schema."""
    response = client.post(
        "/v1/chat/completions",
        headers={"Authorization": "Bearer " + token},
        json=chatRequest,
    )
    # Fail fast on auth/server errors; without this, a non-200 body would
    # surface as a confusing Pydantic ValidationError below.
    assert response.status_code == 200, response.text
    # Parsing into the Pydantic model raises if the schema doesn't match.
    assert ChatCompletionResponse(**response.json())

def test_embedding_array():
    """POST a list-of-strings embedding request and validate the schema."""
    response = client.post(
        "/v1/embeddings",
        headers={"Authorization": "Bearer " + token},
        json=embeddingRequest,
    )
    # Fail fast on auth/server errors; without this, a non-200 body would
    # surface as a confusing Pydantic ValidationError below.
    assert response.status_code == 200, response.text
    # Parsing into the Pydantic model raises if the schema doesn't match.
    assert EmbeddingResponse(**response.json())

def test_embedding_string():
    """The embeddings endpoint must also accept a bare string as "input"."""
    # Build a fresh payload instead of mutating the shared module-level
    # embeddingRequest dict — mutation would couple this test's side effect
    # to any test that runs after it.
    request_body = {**embeddingRequest, "input": "Message to embed"}
    response = client.post(
        "/v1/embeddings",
        headers={"Authorization": "Bearer " + token},
        json=request_body,
    )
    # Fail fast on auth/server errors before schema validation.
    assert response.status_code == 200, response.text
    assert EmbeddingResponse(**response.json())

# def test_completion_chat_b64Image():
# chatRequest["messages"][0]["content"] = [
# {
# "type": "text",
# "text": "What is image ?"
# },
# {
# "type": "image_url",
# "image_url": {"url": b64_image}
# }
# ]
# print(chatRequest)
# response = client.post(
# "/v1/chat/completions",
# headers={"Authorization": "Bearer " + token},
# json=chatRequest
# )
# data = response.json()
# print(data)
# # check type using pydantic class
# assert ChatCompletionResponse(**data)

# def test_completion_chat_urlImage():
# chatRequest["messages"][0]["content"] = [
# {
# "type": "text",
# "text": "What is image ?"
# },
# {
# "type": "image_url",
# "image_url": {"url": url_image}
# }
# ]
# print(chatRequest)
# response = client.post(
# "/v1/chat/completions",
# headers={"Authorization": "Bearer " + token},
# json=chatRequest
# )
# data = response.json()
# print(data)
# # check type using pydantic class
# assert ChatCompletionResponse(**data)
80 changes: 0 additions & 80 deletions backend/tests.py

This file was deleted.

11 changes: 4 additions & 7 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,17 @@ services:
context: ./backend
dockerfile: Dockerfile
container_name: ai-proxy-test
depends_on:
- api
volumes:
- ./backend:/app
- ./config.yaml:/config.yaml
env_file: .env
environment:
- PYTHONUNBUFFERED=1
- API_URL=http://api:8000
- API_TOKEN=autobot3000
command: ["python", "tests.py"]
command: ["pytest"]

api:
profiles:
- api
- test
image: erasme/ai-proxy:latest
build:
context: ./backend
Expand Down Expand Up @@ -76,4 +73,4 @@ services:


volumes:
grafana-storage:
grafana-storage: