Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import uvicorn

if __name__ == "__main__":
uvicorn.run("src.app:app", host="0.0.0.0", port=8020, reload=True)

12 changes: 12 additions & 0 deletions master_ws_client_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import asyncio
from uuid import UUID
from src.master_node.services.websocket_client_service import master_client_ws

async def main():
# Connect to ingress router
await master_client_ws.connect(max_reconnect_attempts=3)
# Start listening for messages
await master_client_ws.listen_for_messages()

# if __name__ == "__main__":
# asyncio.run(main())
22 changes: 22 additions & 0 deletions src/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import asyncio
from fastapi import FastAPI, WebSocket
from contextlib import asynccontextmanager
from uuid import UUID
from src.master_node.services.master_node_websocket_server_service import master_node_ws_server
from master_ws_client_runner import main


@asynccontextmanager
async def lifespan(app: FastAPI):
print("startup")
task = asyncio.create_task(main()) # run main in background
yield
print("shutdown")
task.cancel() # stop the task

app = FastAPI(lifespan=lifespan)

@app.websocket("/ws/connect/{worker_id}")
async def master_node_websocket_connect(websocket: WebSocket, worker_id: UUID):
await master_node_ws_server.connect(worker_id, websocket)
print(f"Worker node: {worker_id}")
11 changes: 0 additions & 11 deletions src/master_node/app.py

This file was deleted.

6 changes: 6 additions & 0 deletions src/master_node/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from dotenv import load_dotenv
import os

load_dotenv()
CORE_API_URI=os.getenv("CORE_API_URI")
INGRESS_ROUTER_URI=os.getenv("INGRESS_ROUTER_URI")
34 changes: 34 additions & 0 deletions src/master_node/models/ingress_router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import re
from typing import Optional
from pydantic import UUID4, BaseModel, Field, field_validator
import ipaddress

label = r"(?!-)[A-Za-z0-9-]{1,63}(?<!-)"
# Require at least one dot, last part must be TLD of letters only
domain_pattern = re.compile(rf"^{label}(?:\.{label})*\.[A-Za-z]{{2,}}$")

class IngressRouter(BaseModel):
# ingress_id: UUID4 = Field(...)
ingress_address: str = Field(...)

@field_validator('ingress_address')
@classmethod
def valid_ip_address(cls, v: str):
v = v.strip()

if not v:
raise ValueError("ingress_address cannot be empty")

try:
ipaddress.ip_address(v)
return v
except ValueError:
...

if domain_pattern.match(v):
return v

raise ValueError("ingress_address should contain a valid IP Address or domain name")

class UpdateIngressRouter(BaseModel):
ingress_address: Optional[str] = None
17 changes: 12 additions & 5 deletions src/master_node/models/payloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,21 @@ class WebsocketMessage(BaseModel):

class JobRequestPayload(BaseModel):
request_id: UUID = Field(...)
job_id: str
master_id: UUID = Field(...)
worker_id: UUID = Field(...)
method: Optional[MethodEnum] = Field(None, description="HTTP method")
path: str = Field(...)
headers: Optional[Dict[str, Any]] = Field(default_factory=dict, description="HTTP headers")
params: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Query parameters")
body: Dict[str, Any] = Field(..., description="Main job data as a JSON object")
body: Any | None = Field(..., description="Main job data as a JSON object")

class JobResponsePayload(BaseModel):
request_id: UUID = Field(...)
status: str = Field(..., description='"ok" or "error"')
result: Dict[str, Any] = Field(default_factory=dict, description="User's actual output in JSON")
error: Optional[str] = Field(None, description="Error details if status='error'")
meta: Dict[str, Any] = Field(default_factory=dict)
job_id: str
master_id: UUID = Field(...)
worker_id: UUID = Field(...)
status_code: int = Field(...)
body: Any = Field(default_factory=dict, description="Output may be error or not")
meta: Dict[str, Any] = Field(default_factory=dict)
headers: Dict[str, Any] = Field(default_factory=dict)
18 changes: 12 additions & 6 deletions src/master_node/services/master_node_websocket_server_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ async def connect(self, worker_id: UUID, websocket: WebSocket):
await websocket.accept()
worker_key = str(worker_id)
self.connections[worker_key] = websocket

print("/n/n/n connect", self.connections)
try:
await self._handle_worker_websocket_message(worker_id, websocket)
except WebSocketDisconnect:
Expand Down Expand Up @@ -97,14 +97,17 @@ async def send_job_rpc_to_worker_node(
timeout: float = 30.0
) -> JobResponsePayload:

print("workers: ", self.connections)
print("handling job in server")

if not self.is_worker_connected(worker_id):
raise WorkerNotConnectedError(f"Worker {worker_id} not connected")

ws = self._get_websocket_or_raise(worker_id)

request_id = job_payload.request_id

response_future: Future[Any] = asyncio.Future()
response_future: Future[JobResponsePayload] = asyncio.Future()
self.pending_requests[str(request_id)] = response_future

ws_message = WebsocketMessage(
Expand All @@ -114,10 +117,11 @@ async def send_job_rpc_to_worker_node(
)

try:
await ws.send_json(ws_message.model_dump())
response = await asyncio.wait_for(response_future, timeout=timeout)
await ws.send_json(ws_message.model_dump(mode="json"))
# response: JobResponsePayload = await asyncio.wait_for(response_future, timeout)
response: JobResponsePayload = await asyncio.wait_for(response_future, None)

return JobResponsePayload(**response)
return response

except ValidationError as ve:
raise InvalidWorkerResponseError(f"Invalid response from worker: {ve}")
Expand Down Expand Up @@ -248,4 +252,6 @@ def is_worker_connected(self, worker_id: UUID) -> bool:
def get_pending_requests_count(self) -> int:
"""Get count of pending RPC requests"""

return len(self.pending_requests)
return len(self.pending_requests)

master_node_ws_server = MasterNodeWebsocketServerService()
170 changes: 170 additions & 0 deletions src/master_node/services/websocket_client_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import json
# import httpx
import websockets
import asyncio
from typing import Any, Dict, Optional
from uuid import uuid4, UUID
from websockets.exceptions import ConnectionClosed, WebSocketException
from src.master_node.config import INGRESS_ROUTER_URI
# from src.master_node.models.ingress_router import IngressRouter
from src.master_node.models.payloads import WebsocketMessage, MessageType, JobRequestPayload, JobResponsePayload
from src.master_node.services.master_node_websocket_server_service import master_node_ws_server

class MasterNodeDiscoveryError(Exception):
...

class MasterNodeNotFound(MasterNodeDiscoveryError):
...

class MasterNodeServerError(MasterNodeDiscoveryError):
...

class MasterNodeInvalidResponse(MasterNodeDiscoveryError):
...

class WebsocketClientService:
def __init__(self,
master_id: UUID,
max_reconnect_attempts: int = 3,
):
self.master_id = str(master_id)
self.websocket = None
self.max_reconnect_attempts = max_reconnect_attempts
self.current_websocket_url = None

async def connect(
self,
max_reconnect_attempts: int,
max_rediscoveries: int = 2,
websocket_url: Optional[str] = None
) -> None:
if not websocket_url:
websocket_url = await self.discover_ingress_router()

current_address = websocket_url
rediscoveries = 0

while rediscoveries <= max_rediscoveries:
rediscoveries += 1
for attempt in range(1, self.max_reconnect_attempts + 1):
try:
self.websocket = await websockets.connect(current_address)
self.current_websocket_url = current_address
return

except ConnectionClosed:
if attempt == self.max_reconnect_attempts:
new_address = await self.discover_ingress_router()
if new_address != current_address:
current_address = new_address

break

except Exception as e:
print(f"Unexpected error during WebSocket connection (attempt {attempt}): {e}")
else:
raise ConnectionError(f"Failed to connect after {max_reconnect_attempts} attempts")

async def listen_for_messages(self) -> None:
if not self.websocket:
raise RuntimeError("Not connected to WebSocket server")

while self.websocket:
message: str | bytes = ""
try:
message = await self.websocket.recv()
data = json.loads(message)
print(f"Received JSON message: {data}")

# Serialize message
data = WebsocketMessage(**data)
if data.type == MessageType.JOB_REQUEST:
await self.handle_job_rpc_request(JobRequestPayload(**data.payloads))

except (ConnectionClosed, WebSocketException):
await self.disconnect()
await self.connect(self.max_reconnect_attempts)
except json.JSONDecodeError:
print(f"Received non-JSON message: {message}")
except Exception:
await self.disconnect()
raise

await self.disconnect()

async def handle_job_rpc_request(self, job_request: JobRequestPayload):
print(master_node_ws_server.get_connected_workers())
print("handling job in client")
response = await master_node_ws_server.send_job_rpc_to_worker_node(
worker_id=job_request.worker_id,
job_payload=job_request
)

message = WebsocketMessage(
request_id=job_request.request_id,
type=MessageType.JOB_RESPONSE,
payloads=response
)

await self.send_message(message=message)

async def send_message(self, message: WebsocketMessage) -> None:
if not self.websocket:
raise RuntimeError("Not connected to WebSocket server")

max_send_attempts = 2
current_attempts = 0

while current_attempts < max_send_attempts:
try:
await self.websocket.send(json.dumps(message.model_dump(mode="json")))
print(f"Sent message: {message}")
return

except (ConnectionClosed, WebSocketException):
await self.disconnect()
await self.connect(self.max_reconnect_attempts)
except Exception:
current_attempts += 1
if current_attempts < max_send_attempts:
continue
raise

async def disconnect(self) -> None:
if self.websocket:
try:
await self.websocket.close()
finally:
self.websocket = None

def is_connected(self) -> bool:
"""
Check if the websocket is currently connected
"""

return self.websocket is not None

async def discover_ingress_router(self) -> str:
# async with httpx.AsyncClient() as client:
# try:
# response = await client.get(f"{CORE_API_URI}/master-node/discover")
# if response.status_code == 404:
# raise MasterNodeNotFound("Master node discovery endpoint returned 404 Not Found")
# elif 500 <= response.status_code < 600:
# raise MasterNodeServerError(f"Master node discovery failed with status {response.status_code}")
# response.raise_for_status()
# ingress_router_data = response.json()
# try:
# ingress_router = IngressRouter(**ingress_router_data)
# except Exception as e:
# raise MasterNodeInvalidResponse(f"Invalid master node data: {e}")

# master_address = str(ingress_router.ingress_address)
# except httpx.RequestError as e:
# raise MasterNodeDiscoveryError(f"HTTP request failed: {e}") from e

websocket_url = f"ws://{INGRESS_ROUTER_URI}/ws/connect/{self.master_id}"
print(f"Discovered ingress router websocket at: {websocket_url}")
return websocket_url

master_client_ws = WebsocketClientService(master_id=UUID("550e8400-e29b-41d4-a716-446655440000"))