diff --git a/main.py b/main.py index e69de29..910521b 100644 --- a/main.py +++ b/main.py @@ -0,0 +1,5 @@ +import uvicorn + +if __name__ == "__main__": + uvicorn.run("src.app:app", host="0.0.0.0", port=8020, reload=True) + \ No newline at end of file diff --git a/master_ws_client_runner.py b/master_ws_client_runner.py new file mode 100644 index 0000000..2182ec3 --- /dev/null +++ b/master_ws_client_runner.py @@ -0,0 +1,12 @@ +import asyncio +from uuid import UUID +from src.master_node.services.websocket_client_service import master_client_ws + +async def main(): + # Connect to ingress router + await master_client_ws.connect(max_reconnect_attempts=3) + # Start listening for messages + await master_client_ws.listen_for_messages() + +# if __name__ == "__main__": +# asyncio.run(main()) \ No newline at end of file diff --git a/src/app.py b/src/app.py new file mode 100644 index 0000000..25e45cc --- /dev/null +++ b/src/app.py @@ -0,0 +1,22 @@ +import asyncio +from fastapi import FastAPI, WebSocket +from contextlib import asynccontextmanager +from uuid import UUID +from src.master_node.services.master_node_websocket_server_service import master_node_ws_server +from master_ws_client_runner import main + + +@asynccontextmanager +async def lifespan(app: FastAPI): + print("startup") + task = asyncio.create_task(main()) # run main in background + yield + print("shutdown") + task.cancel() # stop the task + +app = FastAPI(lifespan=lifespan) + +@app.websocket("/ws/connect/{worker_id}") +async def master_node_websocket_connect(websocket: WebSocket, worker_id: UUID): + await master_node_ws_server.connect(worker_id, websocket) + print(f"Worker node: {worker_id}") \ No newline at end of file diff --git a/src/master_node/app.py b/src/master_node/app.py deleted file mode 100644 index 6b4d273..0000000 --- a/src/master_node/app.py +++ /dev/null @@ -1,11 +0,0 @@ -from fastapi import FastAPI, WebSocket -from uuid import UUID -from .services.master_node_websocket_server_service import 
MasterNodeWebsocketServerService - -app = FastAPI() -ws_service = MasterNodeWebsocketServerService() - -@app.websocket("/ws/connect/{worker_id}") -async def master_node_websocket_connect(websocket: WebSocket, worker_id: UUID): - await ws_service.connect(worker_id, websocket) - print(f"Worker node: {worker_id}") \ No newline at end of file diff --git a/src/master_node/config.py b/src/master_node/config.py index e69de29..cf3c5f2 100644 --- a/src/master_node/config.py +++ b/src/master_node/config.py @@ -0,0 +1,6 @@ +from dotenv import load_dotenv +import os + +load_dotenv() +CORE_API_URI=os.getenv("CORE_API_URI") +INGRESS_ROUTER_URI=os.getenv("INGRESS_ROUTER_URI") \ No newline at end of file diff --git a/src/master_node/models/ingress_router.py b/src/master_node/models/ingress_router.py new file mode 100644 index 0000000..5466a53 --- /dev/null +++ b/src/master_node/models/ingress_router.py @@ -0,0 +1,34 @@ +import re +from typing import Optional +from pydantic import UUID4, BaseModel, Field, field_validator +import ipaddress + +label = r"(?!-)[A-Za-z0-9-]{1,63}(? 
JobResponsePayload: + print("workers: ", self.connections) + print("handling job in server") + if not self.is_worker_connected(worker_id): raise WorkerNotConnectedError(f"Worker {worker_id} not connected") @@ -104,7 +107,7 @@ async def send_job_rpc_to_worker_node( request_id = job_payload.request_id - response_future: Future[Any] = asyncio.Future() + response_future: Future[JobResponsePayload] = asyncio.Future() self.pending_requests[str(request_id)] = response_future ws_message = WebsocketMessage( @@ -114,10 +117,11 @@ async def send_job_rpc_to_worker_node( ) try: - await ws.send_json(ws_message.model_dump()) - response = await asyncio.wait_for(response_future, timeout=timeout) + await ws.send_json(ws_message.model_dump(mode="json")) + # response: JobResponsePayload = await asyncio.wait_for(response_future, timeout) + response: JobResponsePayload = await asyncio.wait_for(response_future, None) - return JobResponsePayload(**response) + return response except ValidationError as ve: raise InvalidWorkerResponseError(f"Invalid response from worker: {ve}") @@ -248,4 +252,6 @@ def is_worker_connected(self, worker_id: UUID) -> bool: def get_pending_requests_count(self) -> int: """Get count of pending RPC requests""" - return len(self.pending_requests) \ No newline at end of file + return len(self.pending_requests) + +master_node_ws_server = MasterNodeWebsocketServerService() \ No newline at end of file diff --git a/src/master_node/services/websocket_client_service.py b/src/master_node/services/websocket_client_service.py new file mode 100644 index 0000000..cf882d0 --- /dev/null +++ b/src/master_node/services/websocket_client_service.py @@ -0,0 +1,170 @@ +import json +# import httpx +import websockets +import asyncio +from typing import Any, Dict, Optional +from uuid import uuid4, UUID +from websockets.exceptions import ConnectionClosed, WebSocketException +from src.master_node.config import INGRESS_ROUTER_URI +# from src.master_node.models.ingress_router import 
from src.master_node.models.payloads import WebsocketMessage, MessageType, JobRequestPayload, JobResponsePayload
from src.master_node.services.master_node_websocket_server_service import master_node_ws_server


class MasterNodeDiscoveryError(Exception):
    """Base class for failures while locating the ingress router."""


class MasterNodeNotFound(MasterNodeDiscoveryError):
    """Reserved for a 404 from the discovery endpoint (see the TODO in discovery)."""


class MasterNodeServerError(MasterNodeDiscoveryError):
    """Reserved for a 5xx from the discovery endpoint (see the TODO in discovery)."""


class MasterNodeInvalidResponse(MasterNodeDiscoveryError):
    """Reserved for an unparsable discovery payload (see the TODO in discovery)."""


class WebsocketClientService:
    """WebSocket client that links this master node to the ingress router.

    The client receives ``JOB_REQUEST`` messages from the router, forwards
    them to the addressed worker through ``master_node_ws_server`` and sends
    the worker's answer back as a ``JOB_RESPONSE`` message.
    """

    def __init__(
        self,
        master_id: UUID,
        max_reconnect_attempts: int = 3,
    ):
        """Store the client configuration; no network I/O happens here.

        Args:
            master_id: Identifier of this master node; kept as a string
                because it is interpolated into the websocket URL.
            max_reconnect_attempts: Default number of connection attempts
                per router address.
        """
        self.master_id = str(master_id)
        # Live connection object returned by ``websockets.connect`` or None.
        self.websocket: Optional[Any] = None
        self.max_reconnect_attempts = max_reconnect_attempts
        # Address of the last successful connection, or None.
        self.current_websocket_url: Optional[str] = None

    async def connect(
        self,
        max_reconnect_attempts: Optional[int] = None,
        max_rediscoveries: int = 2,
        websocket_url: Optional[str] = None,
    ) -> None:
        """Open the connection to the ingress router, retrying on failure.

        The current address is tried ``max_reconnect_attempts`` times. When
        every attempt fails the router address is discovered again and a new
        round starts, for ``max_rediscoveries + 1`` rounds in total.

        Args:
            max_reconnect_attempts: Attempts per address. ``None`` falls back
                to the value given to the constructor.
            max_rediscoveries: Number of extra rounds after the first one.
            websocket_url: Address for the first round; discovered when
                empty or ``None``.

        Raises:
            ConnectionError: Every attempt of every round failed. The last
                underlying error is attached as ``__cause__``.
            MasterNodeDiscoveryError: The router address cannot be determined.
        """
        attempts_per_round = (
            self.max_reconnect_attempts
            if max_reconnect_attempts is None
            else max_reconnect_attempts
        )
        current_address = websocket_url or await self.discover_ingress_router()
        total_rounds = max_rediscoveries + 1
        last_error: Optional[Exception] = None

        for round_number in range(1, total_rounds + 1):
            for attempt in range(1, attempts_per_round + 1):
                try:
                    connection = await websockets.connect(current_address)
                except Exception as exc:
                    # Best effort: any failure to dial counts as one attempt.
                    last_error = exc
                    print(f"Unexpected error during WebSocket connection (attempt {attempt}): {exc}")
                else:
                    self.websocket = connection
                    self.current_websocket_url = current_address
                    return

            # The address may be stale; look it up again before the next
            # round (pointless after the final round).
            if round_number < total_rounds:
                current_address = await self.discover_ingress_router()

        raise ConnectionError(
            f"Failed to connect after {attempts_per_round} attempts"
        ) from last_error

    async def listen_for_messages(self) -> None:
        """Receive router messages until the connection is gone for good.

        JSON messages are parsed into ``WebsocketMessage``; job requests are
        handled inline, so the next message is only read once the current
        job has been answered. A dropped connection triggers a reconnect.
        Non-JSON frames are logged and skipped. Any other error closes the
        connection and is re-raised.

        Raises:
            RuntimeError: ``connect`` has not been called successfully.
            ConnectionError: Reconnecting after a dropped connection failed.
        """
        if not self.websocket:
            raise RuntimeError("Not connected to WebSocket server")

        while self.websocket:
            message: str | bytes = ""
            try:
                message = await self.websocket.recv()
                decoded = json.loads(message)
                print(f"Received JSON message: {decoded}")

                envelope = WebsocketMessage(**decoded)
                if envelope.type == MessageType.JOB_REQUEST:
                    await self.handle_job_rpc_request(JobRequestPayload(**envelope.payloads))

            except (ConnectionClosed, WebSocketException):
                await self.disconnect()
                await self.connect(self.max_reconnect_attempts)
            except json.JSONDecodeError:
                print(f"Received non-JSON message: {message}")
            except Exception:
                await self.disconnect()
                raise

        await self.disconnect()

    async def handle_job_rpc_request(self, job_request: JobRequestPayload) -> None:
        """Forward a job to its worker and relay the worker's response.

        Args:
            job_request: Job payload received from the ingress router; its
                ``worker_id`` selects the target worker.
        """
        print(master_node_ws_server.get_connected_workers())
        print("handling job in client")
        worker_response = await master_node_ws_server.send_job_rpc_to_worker_node(
            worker_id=job_request.worker_id,
            job_payload=job_request,
        )

        reply = WebsocketMessage(
            request_id=job_request.request_id,
            type=MessageType.JOB_RESPONSE,
            payloads=worker_response,
        )
        await self.send_message(message=reply)

    async def send_message(self, message: WebsocketMessage) -> None:
        """Send ``message`` to the ingress router, retrying once on failure.

        Every failed send counts against the attempt budget, so a router
        that keeps dropping the connection cannot cause an endless loop.
        After a connection-level failure the client reconnects before the
        retry.

        Args:
            message: Message to serialise as JSON and send.

        Raises:
            RuntimeError: The client is not connected.
            Exception: The error of the final failed attempt is re-raised.
        """
        if not self.websocket:
            raise RuntimeError("Not connected to WebSocket server")

        # Serialise once: a serialisation error is not worth retrying.
        serialized = json.dumps(message.model_dump(mode="json"))
        max_send_attempts = 2

        for attempt in range(1, max_send_attempts + 1):
            is_last_attempt = attempt == max_send_attempts
            try:
                await self.websocket.send(serialized)
            except (ConnectionClosed, WebSocketException):
                if is_last_attempt:
                    raise
                await self.disconnect()
                await self.connect(self.max_reconnect_attempts)
            except Exception:
                if is_last_attempt:
                    raise
            else:
                print(f"Sent message: {message}")
                return

    async def disconnect(self) -> None:
        """Close the connection if one is open; the handle is always cleared."""
        if self.websocket:
            try:
                await self.websocket.close()
            finally:
                self.websocket = None

    def is_connected(self) -> bool:
        """Return True while a websocket handle is held.

        Only the presence of the handle is checked, not the state of the
        underlying connection.
        """
        return self.websocket is not None

    async def discover_ingress_router(self) -> str:
        """Return the websocket URL of the ingress router for this master.

        The address currently comes from the ``INGRESS_ROUTER_URI`` setting.

        TODO: query ``{CORE_API_URI}/master-node/discover`` instead and map
        404 -> MasterNodeNotFound, 5xx -> MasterNodeServerError, an invalid
        payload -> MasterNodeInvalidResponse and transport errors ->
        MasterNodeDiscoveryError.

        Raises:
            MasterNodeDiscoveryError: ``INGRESS_ROUTER_URI`` is not set.
        """
        if not INGRESS_ROUTER_URI:
            # Fail fast instead of dialling the literal host "None".
            raise MasterNodeDiscoveryError("INGRESS_ROUTER_URI is not configured")

        websocket_url = f"ws://{INGRESS_ROUTER_URI}/ws/connect/{self.master_id}"
        print(f"Discovered ingress router websocket at: {websocket_url}")
        return websocket_url


# Module-level client shared by the application; the master id is fixed here.
master_client_ws = WebsocketClientService(master_id=UUID("550e8400-e29b-41d4-a716-446655440000"))