diff --git a/.vscode/launch.json b/.vscode/launch.json index 977eeda..2dcb99c 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -10,6 +10,7 @@ "request": "launch", "django": true, "module": "mcp_bridge.main", + "pythonArgs": ["-Xutf8"] } ] -} \ No newline at end of file +} diff --git a/mcp_bridge/__init__.py b/mcp_bridge/__init__.py index 08d79c0..93b60a1 100644 --- a/mcp_bridge/__init__.py +++ b/mcp_bridge/__init__.py @@ -1 +1 @@ -__version__ = '0.5.1' \ No newline at end of file +__version__ = '0.5.1' diff --git a/mcp_bridge/config/final.py b/mcp_bridge/config/final.py index 2201f01..5b3940f 100644 --- a/mcp_bridge/config/final.py +++ b/mcp_bridge/config/final.py @@ -7,6 +7,9 @@ class InferenceServer(BaseModel): + type: Literal["openai", "openrouter", "gemini"] = Field( + "openai", description="Type of inference server" + ) # used to apply data mappers base_url: str = Field( default="http://localhost:11434/v1", description="Base URL of the inference server", @@ -24,14 +27,19 @@ class Logging(BaseModel): class SamplingModel(BaseModel): model: Annotated[str, Field(description="Name of the sampling model")] - intelligence: Annotated[float, Field(description="Intelligence of the sampling model")] = 0.5 + intelligence: Annotated[ + float, Field(description="Intelligence of the sampling model") + ] = 0.5 cost: Annotated[float, Field(description="Cost of the sampling model")] = 0.5 speed: Annotated[float, Field(description="Speed of the sampling model")] = 0.5 class Sampling(BaseModel): timeout: Annotated[int, Field(description="Timeout for sampling requests")] = 10 - models: Annotated[list[SamplingModel], Field(description="List of sampling models")] = [] + models: Annotated[ + list[SamplingModel], Field(description="List of sampling models") + ] = [] + class SSEMCPServer(BaseModel): # TODO: expand this once I find a good definition for this diff --git a/mcp_bridge/endpoints.py b/mcp_bridge/endpoints.py index 23326cf..e7cb570 100644 --- a/mcp_bridge/endpoints.py +++ b/mcp_bridge/endpoints.py @@ -2,14 +2,16 @@ from lmos_openai_types import CreateChatCompletionRequest, CreateCompletionRequest +from mcp_bridge.config.final import InferenceServer from mcp_bridge.openai_clients import ( - client, completions, chat_completions, streaming_chat_completions, ) +from mcp_bridge.http_clients import get_client from mcp_bridge.openapi_tags import Tag +from mcp_bridge.config import config router = APIRouter(prefix="/v1", tags=[Tag.openai]) @@ -34,6 +36,45 @@ async def openai_chat_completions(request: CreateChatCompletionRequest): @router.get("/models") async def models(): - """List models""" - response = await client.get("/models") + """List models. 
+ + This is a passthrough to the inference server and returns the same response json.""" + + # this is an ugly hack to work around an upstream bug in the gemini API + if config.inference_server.type == "gemini": + return list_gemini_models() + + response = await get_client().get("/models") return response.json() + +def list_gemini_models(): + """temp hack to fix gemini bug""" + return { + "object": "list", + "data": [ + { + "id": "gemini-2.0-flash-exp", + "object": "model", + "created": 1686935002, + "owned_by": "google", + }, + { + "id": "gemini-1.5-flash", + "object": "model", + "created": 1686935002, + "owned_by": "google", + }, + { + "id": "gemini-1.5-flash-8b", + "object": "model", + "created": 1686935002, + "owned_by": "google", + }, + { + "id": "gemini-1.5-pro", + "object": "model", + "created": 1686935002, + "owned_by": "google", + } + ], + } \ No newline at end of file diff --git a/mcp_bridge/http_clients/__init__.py b/mcp_bridge/http_clients/__init__.py new file mode 100644 index 0000000..81e0d8f --- /dev/null +++ b/mcp_bridge/http_clients/__init__.py @@ -0,0 +1,35 @@ +from httpx import AsyncClient +from mcp_bridge.config import config + + +# change this if you want to hard fork the repo +# it's used to show ranking on openrouter and other inference providers +BRIDGE_REPO_URL = "https://github.com/SecretiveShell/MCP-Bridge" +BRIDGE_APP_TITLE = "MCP Bridge" + + +def get_client() -> AsyncClient: + client: AsyncClient = AsyncClient( + base_url=config.inference_server.base_url, + headers={"Content-Type": "application/json"}, + timeout=10000, + ) + + # generic openai + if config.inference_server.type == "openai": + client.headers["Authorization"] = rf"Bearer {config.inference_server.api_key}" + return client + + # openrouter + if config.inference_server.type == "openrouter": + client.headers["Authorization"] = rf"Bearer {config.inference_server.api_key}" + client.headers["HTTP-Referer"] = BRIDGE_REPO_URL + client.headers["X-Title"] = BRIDGE_APP_TITLE + return client + + # gemini models + if config.inference_server.type == "gemini": + client.headers["Authorization"] = rf"Bearer {config.inference_server.api_key}" + return client + + raise NotImplementedError("Inference Server Type not supported") diff --git a/mcp_bridge/inference_engine_mappers/chat/gemini/request.py b/mcp_bridge/inference_engine_mappers/chat/gemini/request.py new file mode 100644 index 0000000..3298a06 --- /dev/null +++ b/mcp_bridge/inference_engine_mappers/chat/gemini/request.py @@ -0,0 +1,8 @@ +from lmos_openai_types import CreateChatCompletionRequest + + +def chat_completion_gemini_request(data: CreateChatCompletionRequest) -> dict: + + dumped_data = data.model_dump(exclude_defaults=True, exclude_none=True, exclude_unset=True) + + return dumped_data diff --git a/mcp_bridge/inference_engine_mappers/chat/gemini/response.py b/mcp_bridge/inference_engine_mappers/chat/gemini/response.py new file mode 100644 index 0000000..7dd0d0d --- /dev/null +++ b/mcp_bridge/inference_engine_mappers/chat/gemini/response.py @@ -0,0 +1,10 @@ +from lmos_openai_types import CreateChatCompletionResponse + + +def chat_completion_gemini_response(data: dict) -> CreateChatCompletionResponse: + + if "id" not in data or data["id"] == "": + data["id"] = "default-id" + + validated_data = CreateChatCompletionResponse.model_validate(data) + return validated_data \ No newline at end of file diff --git a/mcp_bridge/inference_engine_mappers/chat/gemini/stream_response.py b/mcp_bridge/inference_engine_mappers/chat/gemini/stream_response.py new file mode 100644
index 0000000..f1b0ec4 --- /dev/null +++ b/mcp_bridge/inference_engine_mappers/chat/gemini/stream_response.py @@ -0,0 +1,14 @@ +from lmos_openai_types import CreateChatCompletionStreamResponse +from loguru import logger + + +def chat_completion_gemini_stream_response( + data: dict, +) -> CreateChatCompletionStreamResponse: # type: ignore + + logger.debug(f"data: {data}") + + if "id" not in data or data["id"] == "": + data["id"] = "default-id" + + return CreateChatCompletionStreamResponse.model_validate(data) diff --git a/mcp_bridge/inference_engine_mappers/chat/generic.py b/mcp_bridge/inference_engine_mappers/chat/generic.py new file mode 100644 index 0000000..16c19a8 --- /dev/null +++ b/mcp_bridge/inference_engine_mappers/chat/generic.py @@ -0,0 +1,19 @@ +from lmos_openai_types import ( + CreateChatCompletionRequest, + CreateChatCompletionResponse, + CreateChatCompletionStreamResponse, +) + + +def chat_completion_generic_request(data: CreateChatCompletionRequest) -> dict: + return data.model_dump(exclude_defaults=True, exclude_none=True, exclude_unset=True) + + +def chat_completion_generic_response(data: dict) -> CreateChatCompletionResponse: + return CreateChatCompletionResponse.model_validate(data) + + +def chat_completion_generic_stream_response( + data: dict, +) -> CreateChatCompletionStreamResponse: + return CreateChatCompletionStreamResponse.model_validate(data) diff --git a/mcp_bridge/inference_engine_mappers/chat/openrouter/request.py b/mcp_bridge/inference_engine_mappers/chat/openrouter/request.py new file mode 100644 index 0000000..0244d38 --- /dev/null +++ b/mcp_bridge/inference_engine_mappers/chat/openrouter/request.py @@ -0,0 +1,36 @@ +import json +import secrets +from typing import Any, cast +from lmos_openai_types import CreateChatCompletionRequest +from loguru import logger + + +def chat_completion_openrouter_request(data: CreateChatCompletionRequest) -> dict: + + dumped_data = data.model_dump(exclude_defaults=True, exclude_none=True, exclude_unset=True) + + # make sure we have a tool call id for each tool call + try: + for message in dumped_data["messages"]: + + message = cast(dict[str, Any], message) + + if message["role"] == "assistant": + if message.get("tool_calls") is None: + continue + for tool_call in message["tool_calls"]: + tool_call["tool_call_id"] = tool_call.get("id", secrets.token_hex(16)) + + if message["role"] == "tool": + if message.get("tool_call_id") is None: + message["tool_call_id"] = secrets.token_hex(16) + if message.get("id") is None: + message["id"] = message["tool_call_id"] + + except Exception as e: + print(e) + + logger.debug(f"dumped data: {dumped_data}") + logger.debug(f"json dumped data: {json.dumps(dumped_data)}") + + return dumped_data diff --git a/mcp_bridge/inference_engine_mappers/chat/openrouter/response.py b/mcp_bridge/inference_engine_mappers/chat/openrouter/response.py new file mode 100644 index 0000000..1a96bcf --- /dev/null +++ b/mcp_bridge/inference_engine_mappers/chat/openrouter/response.py @@ -0,0 +1,21 @@ +import secrets +from typing import cast +from lmos_openai_types import CreateChatCompletionResponse +from loguru import logger + + +def chat_completion_openrouter_response(data: dict) -> CreateChatCompletionResponse: + validated_data = CreateChatCompletionResponse.model_validate(data) + + # make sure tool call ids are not none + for choice in validated_data.choices: + if choice.message.tool_calls is None: + continue + for tool_call in choice.message.tool_calls: + logger.error(f"tool call: {tool_call[1]}") + for calls in 
tool_call[1]: + if calls.id is None: + calls.id = secrets.token_hex(16) + + logger.debug(f"validated data: {validated_data.model_dump_json()}") + return validated_data \ No newline at end of file diff --git a/mcp_bridge/inference_engine_mappers/chat/openrouter/stream_response.py b/mcp_bridge/inference_engine_mappers/chat/openrouter/stream_response.py new file mode 100644 index 0000000..1c0f628 --- /dev/null +++ b/mcp_bridge/inference_engine_mappers/chat/openrouter/stream_response.py @@ -0,0 +1,13 @@ +from lmos_openai_types import CreateChatCompletionStreamResponse + + +def chat_completion_openrouter_stream_response( + data: dict, +) -> CreateChatCompletionStreamResponse: # type: ignore + try: + data["choices"][0]["finish_reason"] = data["choices"][0][ + "finish_reason" + ].lower() # type: ignore + except Exception: + pass + return CreateChatCompletionStreamResponse.model_validate(data) diff --git a/mcp_bridge/inference_engine_mappers/chat/requester.py b/mcp_bridge/inference_engine_mappers/chat/requester.py new file mode 100644 index 0000000..aa3a47c --- /dev/null +++ b/mcp_bridge/inference_engine_mappers/chat/requester.py @@ -0,0 +1,21 @@ +from mcp_bridge.inference_engine_mappers.chat.gemini.request import chat_completion_gemini_request +from .generic import chat_completion_generic_request +from .openrouter.request import chat_completion_openrouter_request +from lmos_openai_types import CreateChatCompletionRequest +from mcp_bridge.config import config + + +def chat_completion_requester(data: CreateChatCompletionRequest) -> dict: + client_type = config.inference_server.type + + match client_type: + # apply incoming data mappers + case "openai": + return chat_completion_generic_request(data) + case "openrouter": + # TODO: implement openrouter requester + return chat_completion_openrouter_request(data) + case "gemini": + return chat_completion_gemini_request(data) + case _: + return chat_completion_generic_request(data) \ No newline at end of file diff --git a/mcp_bridge/inference_engine_mappers/chat/responder.py b/mcp_bridge/inference_engine_mappers/chat/responder.py new file mode 100644 index 0000000..f74ea62 --- /dev/null +++ b/mcp_bridge/inference_engine_mappers/chat/responder.py @@ -0,0 +1,21 @@ +from mcp_bridge.inference_engine_mappers.chat.gemini.response import chat_completion_gemini_response +from .generic import chat_completion_generic_response +from .openrouter.response import chat_completion_openrouter_response +from lmos_openai_types import CreateChatCompletionResponse +from mcp_bridge.config import config + + +def chat_completion_responder(data: dict) -> CreateChatCompletionResponse: + client_type = config.inference_server.type + + match client_type: + # apply incoming data mappers + case "openai": + return chat_completion_generic_response(data) + case "openrouter": + # TODO: implement openrouter responser + return chat_completion_openrouter_response(data) + case "gemini": + return chat_completion_gemini_response(data) + case _: + return chat_completion_generic_response(data) diff --git a/mcp_bridge/inference_engine_mappers/chat/stream_responder.py b/mcp_bridge/inference_engine_mappers/chat/stream_responder.py new file mode 100644 index 0000000..fafb166 --- /dev/null +++ b/mcp_bridge/inference_engine_mappers/chat/stream_responder.py @@ -0,0 +1,21 @@ +from mcp_bridge.inference_engine_mappers.chat.gemini.stream_response import chat_completion_gemini_stream_response +from .generic import chat_completion_generic_stream_response +from .openrouter.stream_response import 
chat_completion_openrouter_stream_response +from lmos_openai_types import CreateChatCompletionStreamResponse +from mcp_bridge.config import config + + +def chat_completion_stream_responder(data: dict) -> CreateChatCompletionStreamResponse: + client_type = config.inference_server.type + + match client_type: + # apply incoming data mappers + case "openai": + return chat_completion_generic_stream_response(data) + case "openrouter": + # TODO: implement openrouter responser + return chat_completion_openrouter_stream_response(data) + case "gemini": + return chat_completion_gemini_stream_response(data) + case _: + return chat_completion_generic_stream_response(data) diff --git a/mcp_bridge/main.py b/mcp_bridge/main.py index 405223f..0a12931 100644 --- a/mcp_bridge/main.py +++ b/mcp_bridge/main.py @@ -10,6 +10,7 @@ from mcp_bridge.config import config from loguru import logger + def create_app() -> FastAPI: """ Create and configure the FastAPI application. @@ -46,11 +47,16 @@ def create_app() -> FastAPI: return app + app = create_app() + def run(): import uvicorn + from mcp_bridge.config import config + uvicorn.run(app, host=config.network.host, port=config.network.port) + if __name__ == "__main__": - run() \ No newline at end of file + run() diff --git a/mcp_bridge/mcpManagement/resources.py b/mcp_bridge/mcpManagement/resources.py index 0a2c583..c9c7197 100644 --- a/mcp_bridge/mcpManagement/resources.py +++ b/mcp_bridge/mcpManagement/resources.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter from mcp_bridge.mcp_clients.McpClientManager import ClientManager from mcp.types import ListResourcesResult diff --git a/mcp_bridge/mcp_clients/AbstractClient.py b/mcp_bridge/mcp_clients/AbstractClient.py index 2fc1532..b07ab01 100644 --- a/mcp_bridge/mcp_clients/AbstractClient.py +++ b/mcp_bridge/mcp_clients/AbstractClient.py @@ -41,9 +41,13 @@ async def _session_maintainer(self): try: await self._maintain_session() except FileNotFoundError as e: - logger.error(f"failed to maintain session for {self.name}: file {e.filename} not found.") + logger.error( + f"failed to maintain session for {self.name}: file {e.filename} not found." + ) except Exception as e: - logger.error(f"failed to maintain session for {self.name}: {type(e)} {e.args}") + logger.error( + f"failed to maintain session for {self.name}: {type(e)} {e.args}" + ) logger.debug(f"restarting session for {self.name}") await asyncio.sleep(0.5) @@ -139,10 +143,11 @@ async def _wait_for_session(self, timeout: int = 5, http_error: bool = True): except asyncio.TimeoutError: if http_error: raise HTTPException( - status_code=500, detail=f"Could not connect to MCP server \"{self.name}\"." + status_code=500, + detail=f'Could not connect to MCP server "{self.name}".', ) - raise TimeoutError(f"Could not connect to MCP server \"{self.name}\"." 
) + raise TimeoutError(f'Could not connect to MCP server "{self.name}".') assert self.session is not None, "Session is None" diff --git a/mcp_bridge/mcp_clients/McpClientManager.py b/mcp_bridge/mcp_clients/McpClientManager.py index 7bf6c89..f30eff8 100644 --- a/mcp_bridge/mcp_clients/McpClientManager.py +++ b/mcp_bridge/mcp_clients/McpClientManager.py @@ -40,7 +40,7 @@ async def construct_client(self, name, server_config) -> client_types: client = SseClient(name, server_config) # type: ignore await client.start() return client - + if isinstance(server_config, DockerMCPServer): client = DockerClient(name, server_config) await client.start() @@ -56,7 +56,6 @@ def get_clients(self): async def get_client_from_tool(self, tool: str): for name, client in self.get_clients(): - # client cannot have tools if it is not connected if not client.session: continue @@ -71,7 +70,6 @@ async def get_client_from_tool(self, tool: str): async def get_client_from_prompt(self, prompt: str): for name, client in self.get_clients(): - # client cannot have prompts if it is not connected if not client.session: continue diff --git a/mcp_bridge/mcp_clients/StdioClient.py b/mcp_bridge/mcp_clients/StdioClient.py index 1923f40..189ace7 100644 --- a/mcp_bridge/mcp_clients/StdioClient.py +++ b/mcp_bridge/mcp_clients/StdioClient.py @@ -12,6 +12,7 @@ # Keywords to identify virtual environment variables venv_keywords = ["CONDA", "VIRTUAL", "PYTHON"] + class StdioClient(GenericMcpClient): config: StdioServerParameters @@ -25,7 +26,8 @@ def __init__(self, name: str, config: StdioServerParameters) -> None: env = dict(os.environ.copy()) env = { - key: value for key, value in env.items() + key: value + for key, value in env.items() if not any(key.startswith(keyword) for keyword in venv_keywords) } diff --git a/mcp_bridge/mcp_clients/session.py b/mcp_bridge/mcp_clients/session.py index 56d1f94..ca7cc8f 100644 --- a/mcp_bridge/mcp_clients/session.py +++ b/mcp_bridge/mcp_clients/session.py @@ -25,7 +25,6 @@ class McpClientSession( types.ServerNotification, ] ): - def __init__( self, read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], @@ -75,11 +74,11 @@ async def initialize(self) -> types.InitializeResult: capabilities=types.ClientCapabilities( sampling=types.SamplingCapability(), experimental=None, - roots=types.RootsCapability( - listChanged=True - ), + roots=types.RootsCapability(listChanged=True), + ), + clientInfo=types.Implementation( + name="MCP-Bridge", version=version ), - clientInfo=types.Implementation(name="MCP-Bridge", version=version), ), ) ), @@ -273,9 +272,10 @@ async def _received_request( client_response = types.ClientResult(**response.model_dump()) await responder.respond(client_response) - async def sample(self, params: types.CreateMessageRequestParams) -> types.CreateMessageResult: + async def sample( + self, params: types.CreateMessageRequestParams + ) -> types.CreateMessageResult: logger.info("got sampling request from mcp server") resp = await handle_sampling_message(params) logger.info("finished sampling request from mcp server") return resp - \ No newline at end of file diff --git a/mcp_bridge/models/chatCompletionStreamResponse.py b/mcp_bridge/models/chatCompletionStreamResponse.py index a3bcefc..daee7fb 100644 --- a/mcp_bridge/models/chatCompletionStreamResponse.py +++ b/mcp_bridge/models/chatCompletionStreamResponse.py @@ -15,7 +15,7 @@ class Choice(BaseModel): class SSEData(BaseModel): - id: str + id: str = "default-id" object: str created: int model: str diff --git 
a/mcp_bridge/models/upstream_error.py b/mcp_bridge/models/upstream_error.py new file mode 100644 index 0000000..67d144a --- /dev/null +++ b/mcp_bridge/models/upstream_error.py @@ -0,0 +1,10 @@ +from pydantic import BaseModel, Field +from typing import Annotated + + +class UpstreamErrorDetails(BaseModel): + message: Annotated[str, Field(description="Error message")] = "An upstream error occurred" + code: Annotated[str | None, Field(description="Error code")] = "UPSTREAM_ERROR" + +class UpstreamError(BaseModel): + error: Annotated[UpstreamErrorDetails, Field(description="Error details")] \ No newline at end of file diff --git a/mcp_bridge/openai_clients/__init__.py b/mcp_bridge/openai_clients/__init__.py index b47def7..c7b4a80 100644 --- a/mcp_bridge/openai_clients/__init__.py +++ b/mcp_bridge/openai_clients/__init__.py @@ -1,6 +1,5 @@ -from .genericHttpxClient import client from .completion import completions from .chatCompletion import chat_completions from .streamChatCompletion import streaming_chat_completions -__all__ = ["client", "completions", "chat_completions", "streaming_chat_completions"] +__all__ = ["completions", "chat_completions", "streaming_chat_completions"] diff --git a/mcp_bridge/openai_clients/chatCompletion.py b/mcp_bridge/openai_clients/chatCompletion.py index a43d7f6..f67cf60 100644 --- a/mcp_bridge/openai_clients/chatCompletion.py +++ b/mcp_bridge/openai_clients/chatCompletion.py @@ -1,16 +1,41 @@ +import secrets +import time + from lmos_openai_types import ( + ChatCompletionResponseMessage, + Choice1, CreateChatCompletionRequest, CreateChatCompletionResponse, ChatCompletionRequestMessage, + FinishReason1, ) -from .utils import call_tool, chat_completion_add_tools -from .genericHttpxClient import client -from mcp_bridge.mcp_clients.McpClientManager import ClientManager -from mcp_bridge.tool_mappers import mcp2openai +from .utils import call_tool, chat_completion_add_tools, validate_if_json_object_parsable, json_pretty_print +from mcp_bridge.http_clients import get_client +from mcp_bridge.inference_engine_mappers.chat.requester import chat_completion_requester +from mcp_bridge.inference_engine_mappers.chat.responder import chat_completion_responder from loguru import logger import json +def format_error_as_chat_completion(message: str) -> CreateChatCompletionResponse: + return CreateChatCompletionResponse.model_validate( + { + "model": "MCP-Bridge", + "choices": [ + { + "index": 0, + "finish_reason": "stop", + "message": { + "content": message, + "role": "assistant", + } + } + ], + "id": secrets.token_hex(16), + "created": int(time.time()), + "object": "chat.completion", + } + ) async def chat_completions( request: CreateChatCompletionRequest, @@ -22,22 +47,30 @@ async def chat_completions( while True: # logger.debug(request.model_dump_json()) - text = ( - await client.post( + response = await get_client().post( "/chat/completions", - #content=request.model_dump_json( - # exclude_defaults=True, exclude_none=True, exclude_unset=True - #), - json=request.model_dump(exclude_defaults=True, exclude_none=True, exclude_unset=True), + json=chat_completion_requester(request), ) - ).text + text = response.text logger.debug(text) try: - response = CreateChatCompletionResponse.model_validate_json(text) + response = chat_completion_responder(json.loads(text)) except Exception as e: logger.error(f"Error parsing response: {text}") logger.error(e) - return + + # openrouter returns a json error message + try: + response = json.loads(text) + return
format_error_as_chat_completion(f"Upstream error: {response['error']['message']}") + except Exception: + pass + + return format_error_as_chat_completion(f"Error parsing response: {text}") + + if not response.choices: + logger.error("no choices found in response") + return format_error_as_chat_completion("no choices found in response") msg = response.choices[0].message msg = ChatCompletionRequestMessage( @@ -53,11 +86,22 @@ async def chat_completions( return response logger.debug("tool calls found") + + logger.debug("clearing tool contexts to prevent tool call loops") + request.tools = None + for tool_call in response.choices[0].message.tool_calls.root: logger.debug( - f"tool call: {tool_call.function.name} arguments: {json.loads(tool_call.function.arguments)}" + f"tool call: {tool_call.function.name}" ) + if validate_if_json_object_parsable(tool_call.function.arguments): + logger.debug(f"arguments:\n{json_pretty_print(tool_call.function.arguments)}") + else: + logger.debug("non-json arguments given: %s" % tool_call.function.arguments) + logger.debug("unable to parse tool call argument as json. skipping...") + continue + # FIXME: this can probably be done in parallel using asyncio gather tool_call_result = await call_tool( tool_call.function.name, tool_call.function.arguments @@ -84,9 +128,9 @@ async def chat_completions( { "role": "tool", "content": tools_content, - "tool_call_id": tool_call.id, + "tool_call_id": tool_call.id or secrets.token_hex(16), } ) ) - logger.debug("sending next iteration of chat completion request") + logger.debug("sending next iteration of chat completion request") diff --git a/mcp_bridge/openai_clients/completion.py b/mcp_bridge/openai_clients/completion.py index 42d06f1..8377699 100644 --- a/mcp_bridge/openai_clients/completion.py +++ b/mcp_bridge/openai_clients/completion.py @@ -1,11 +1,11 @@ from lmos_openai_types import CreateCompletionRequest -from .genericHttpxClient import client +from mcp_bridge.http_clients import get_client async def completions(request: CreateCompletionRequest) -> dict: """performs a completion using the inference server""" - response = await client.post( + response = await get_client().post( "/completions", json=request.model_dump( exclude_defaults=True, exclude_none=True, exclude_unset=True diff --git a/mcp_bridge/openai_clients/genericHttpxClient.py b/mcp_bridge/openai_clients/genericHttpxClient.py deleted file mode 100644 index a89506f..0000000 --- a/mcp_bridge/openai_clients/genericHttpxClient.py +++ /dev/null @@ -1,8 +0,0 @@ -from httpx import AsyncClient -from mcp_bridge.config import config - -client: AsyncClient = AsyncClient( - base_url=config.inference_server.base_url, - headers={"Authorization": f"Bearer {config.inference_server.api_key}", "Content-Type": "application/json"}, - timeout=10000, -) diff --git a/mcp_bridge/openai_clients/streamChatCompletion.py b/mcp_bridge/openai_clients/streamChatCompletion.py index 67c7f67..1862cb0 100644 --- a/mcp_bridge/openai_clients/streamChatCompletion.py +++ b/mcp_bridge/openai_clients/streamChatCompletion.py @@ -1,36 +1,69 @@ +import datetime import json +import os +import secrets +import time +import traceback from typing import Optional -from fastapi import HTTPException +from secrets import token_hex from lmos_openai_types import ( ChatCompletionMessageToolCall, ChatCompletionRequestMessage, CreateChatCompletionRequest, CreateChatCompletionStreamResponse, Function1, + FinishReason1, ) -from .utils import call_tool, chat_completion_add_tools -from mcp_bridge.models import SSEData 
-from .genericHttpxClient import client -from mcp_bridge.mcp_clients.McpClientManager import ClientManager -from mcp_bridge.tool_mappers import mcp2openai + +from mcp_bridge.inference_engine_mappers.chat.requester import chat_completion_requester +from mcp_bridge.inference_engine_mappers.chat.stream_responder import ( + chat_completion_stream_responder, +) +from .utils import ( + call_tool, + chat_completion_add_tools, + json_pretty_print, + salvage_parsable_json_object, + validate_if_json_object_parsable, +) +from mcp_bridge.models import SSEData, upstream_error +from mcp_bridge.http_clients import get_client from loguru import logger from httpx_sse import aconnect_sse -from sse_starlette.sse import EventSourceResponse, ServerSentEvent +from sse_starlette.sse import EventSourceResponse +from sse_starlette.event import ServerSentEvent async def streaming_chat_completions(request: CreateChatCompletionRequest): # raise NotImplementedError("Streaming Chat Completion is not supported") - try: - return EventSourceResponse( - content=chat_completions(request), - media_type="text/event-stream", - headers={"Cache-Control": "no-cache"}, - ) - - except Exception as e: - logger.error(e) + return EventSourceResponse( + content=chat_completions(request), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache"}, + ) + + +def format_error_as_sse(message: str) -> str: + return SSEData.model_validate( + { + "id": f"error-{token_hex(16)}", + "provider": "MCP-Bridge", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": "MCP-Bridge", + "choices": [ + { + "index": 0, + "delta": { + "role": "assistant", + "content": message, + }, + } + ], + } + ).model_dump_json() async def chat_completions(request: CreateChatCompletionRequest): @@ -46,36 +79,62 @@ async def chat_completions(request: CreateChatCompletionRequest): # exclude_defaults=True, exclude_none=True, exclude_unset=True # ) - json_data = json.dumps(request.model_dump( - exclude_defaults=True, exclude_none=True, exclude_unset=True - )) + json_data = json_pretty_print(chat_completion_requester(request)) - # logger.debug(json_data) + logger.debug("Request JSON:\n%s" % json_data) last: Optional[CreateChatCompletionStreamResponse] = None # last message tool_call_name: str = "" tool_call_json: str = "" + has_tool_calls: bool = False should_forward: bool = True response_content: str = "" tool_call_id: str = "" async with aconnect_sse( - client, "post", "/chat/completions", content=json_data + get_client(), "post", "/chat/completions", content=json_data ) as event_source: - + logger.debug(event_source.response.status_code) + # check if the content type is correct because the aiter_sse method # will raise an exception if the content type is not correct if "Content-Type" in event_source.response.headers: content_type = event_source.response.headers["Content-Type"] + if "application/json" in content_type: + logger.error(f"Unexpected Content-Type: {content_type}") + error_data = await event_source.response.aread() + # logger.error(f"Request URL: {event_source.response.url}") + # logger.error(f"Response Status: {event_source.response.status_code}") + # logger.error(f"Response Data: {error_data.decode(event_source.response.encoding or 'utf-8')}") + # raise HTTPException(status_code=500, detail="Unexpected Content-Type") + data = json.loads( + error_data.decode(event_source.response.encoding or "utf-8") + ) + if message := data.get("error", {}).get("message"): + logger.error(f"Upstream error: {message}") + yield 
format_error_as_sse(message) + yield ServerSentEvent( + event="message", data="[DONE]", id=None, retry=None + ) # close the stream with the standard [DONE] terminator + return + if "text/event-stream" not in content_type: logger.error(f"Unexpected Content-Type: {content_type}") error_data = await event_source.response.aread() logger.error(f"Request URL: {event_source.response.url}") logger.error(f"Request Data: {json_data}") - logger.error(f"Response Status: {event_source.response.status_code}") - logger.error(f"Response Data: {error_data.decode(event_source.response.encoding or 'utf-8')}") - raise HTTPException(status_code=500, detail="Unexpected Content-Type") + logger.error( + f"Response Status: {event_source.response.status_code}" + ) + logger.error( + f"Response Data: {error_data.decode(event_source.response.encoding or 'utf-8')}" + ) + yield format_error_as_sse("Upstream error: Unexpected Content-Type") + yield ServerSentEvent( + event="message", data="[DONE]", id=None, retry=None + ) + return # iterate over the SSE stream async for sse in event_source.aiter_sse(): @@ -93,19 +152,26 @@ logger.debug("inference serverstream done") break - # for some reason openrouter uses uppercase for finish_reason + # try to parse the data as json, if this fails we assume it is an error message + # if parsing fails we send the error message to the client + dict_data = json.loads(data) try: - data['choices'][0]['finish_reason'] = data['choices'][0]['finish_reason'].lower() # type: ignore - except Exception as e: - logger.debug(f"failed to lowercase finish_reason: {e}") - - try: - parsed_data = CreateChatCompletionStreamResponse.model_validate_json( - data - ) - except Exception as e: + parsed_data = chat_completion_stream_responder(dict_data) + except Exception: logger.debug(data) - raise e + try: + parsed_error_data = upstream_error.UpstreamError.model_validate_json(data) + yield format_error_as_sse(parsed_error_data.error.message) + except Exception: + yield format_error_as_sse(f"Error parsing response: {json.loads(data)}") + + yield ServerSentEvent(event="message", data="[DONE]", id=None, retry=None) + return + + # handle empty response (usually caused by "usage" reporting) + if len(parsed_data.choices) == 0: + logger.debug("no choices found in response") + continue # add the delta to the response content content = parsed_data.choices[0].delta.content @@ -121,6 +187,9 @@ fully_done = True else: should_forward = False + + if parsed_data.choices[0].finish_reason.value == "tool_calls": + has_tool_calls = True # this manages the incoming tool call schema # most of this is assertions to please mypy @@ -134,6 +203,8 @@ name = name if name is not None else "" tool_call_name = name if tool_call_name == "" else tool_call_name + logger.debug(f"ARGS: {parsed_data.choices[0].delta.tool_calls[0].function.arguments}") + call_id = parsed_data.choices[0].delta.tool_calls[0].id call_id = call_id if call_id is not None else "" tool_call_id = id if tool_call_id == "" else tool_call_id @@ -151,13 +222,40 @@ # save the last message last = parsed_data + # perform early stopping on parsable tool_call_json + if tool_call_json: + if tool_call_json.strip().startswith("{"): + if validate_if_json_object_parsable(tool_call_json): + logger.debug( + f"tool call json '{tool_call_json}' is parsable now."
+ ) + logger.debug("exiting message receive loop") + last.choices[0].finish_reason = FinishReason1.tool_calls + break + salvaged_json_object = salvage_parsable_json_object( + tool_call_json + ) + if salvaged_json_object: + tool_call_json = salvaged_json_object + logger.debug( + f"tool call json '{tool_call_json}' is salvagable now." + ) + logger.debug("salvaged json content:", tool_call_json) + logger.debug("exiting message receive loop") + last.choices[0].finish_reason = FinishReason1.tool_calls + break + # ideally we should check this properly assert last is not None - assert last.choices[0].finish_reason is not None - if last.choices[0].finish_reason.value in ["stop", "length"]: - logger.debug("no tool calls found") - fully_done = True + if last.choices[0].finish_reason: + if last.choices[0].finish_reason.value in ["stop", "length"]: + logger.debug("no tool calls found") + fully_done = True + continue + + if last.choices[0].finish_reason is None and not has_tool_calls: + logger.debug("no finish reason found") continue logger.debug("tool calls found") @@ -165,6 +263,12 @@ async def chat_completions(request: CreateChatCompletionRequest): f"{tool_call_name=} {tool_call_json=}" ) # this should not be error but its easier to debug + logger.debug("clearing tool contexts to prevent tool call loops") + request.tools = None + + if tool_call_id is None or tool_call_id == "": + tool_call_id = secrets.token_hex(16) + # add received message to the history msg = ChatCompletionRequestMessage( role="assistant", diff --git a/mcp_bridge/openai_clients/streamCompletion.py b/mcp_bridge/openai_clients/streamCompletion.py deleted file mode 100644 index e69de29..0000000 diff --git a/mcp_bridge/openai_clients/utils.py b/mcp_bridge/openai_clients/utils.py index 58c269b..010974c 100644 --- a/mcp_bridge/openai_clients/utils.py +++ b/mcp_bridge/openai_clients/utils.py @@ -3,24 +3,53 @@ from lmos_openai_types import CreateChatCompletionRequest import mcp.types import json +import traceback from mcp_bridge.mcp_clients.McpClientManager import ClientManager from mcp_bridge.tool_mappers import mcp2openai +def json_pretty_print(obj) -> str: + if type(obj) == bytes: + obj = obj.decode() + if type(obj) == str: + obj = json.loads(obj) + ret = json.dumps(obj, indent=4, ensure_ascii=False) + return ret + +def validate_if_json_object_parsable(content: str): + try: + json.loads(content) + return True + except ValueError: + return False + + +def salvage_parsable_json_object(content: str): + content = content.strip() + for i in range(0, len(content)): + snippet = content[: len(content) - i] + if validate_if_json_object_parsable(snippet): + return snippet + async def chat_completion_add_tools(request: CreateChatCompletionRequest): request.tools = [] + logger.info("adding tools to request") for _, session in ClientManager.get_clients(): # if session is None, then the client is not running if session.session is None: - logger.error(f"session is `None` for {session.name}") + logger.error(f"session is `None` for {session.name}") # Date:2025/01/25 why not running? continue - + logger.debug(f"session ready for {session.name}") tools = await session.session.list_tools() for tool in tools.tools: request.tools.append(mcp2openai(tool)) - + + if request.tools == []: + logger.info("this request loads no tools") + # raise Exception("no tools found. 
unable to initiate chat completion.") + request.tools = None return request @@ -42,9 +71,10 @@ async def call_tool( return None try: - tool_call_args = json.loads(tool_call_json) + tool_call_args = json.loads(tool_call_json) # Date: 2025/01/26 cannot load this tool call json? except json.JSONDecodeError: logger.error(f"failed to decode json for {tool_call_name}") + traceback.print_exc() return None return await session.call_tool(tool_call_name, tool_call_args, timeout) diff --git a/mcp_bridge/sampling/modelSelector.py b/mcp_bridge/sampling/modelSelector.py index ea70be7..d2e5731 100644 --- a/mcp_bridge/sampling/modelSelector.py +++ b/mcp_bridge/sampling/modelSelector.py @@ -4,21 +4,29 @@ from mcp_bridge.config import config + def euclidean_distance(point1, point2): """ Calculates the Euclidean distance between two points, ignoring None values. """ - valid_dimensions = [(p1, p2) for p1, p2 in zip(point1, point2) if p1 is not None and p2 is not None] + valid_dimensions = [ + (p1, p2) for p1, p2 in zip(point1, point2) if p1 is not None and p2 is not None + ] if not valid_dimensions: # No valid dimensions to compare - return float('inf') - + return float("inf") + return math.sqrt(sum((p1 - p2) ** 2 for p1, p2 in valid_dimensions)) + def find_best_model(preferences: ModelPreferences): distance = math.inf preffered_model = None - preference_points = (preferences.intelligencePriority, preferences.speedPriority, preferences.costPriority) + preference_points = ( + preferences.intelligencePriority, + preferences.speedPriority, + preferences.costPriority, + ) if preference_points == (None, None, None): return config.sampling.models[0] @@ -29,8 +37,8 @@ def find_best_model(preferences: ModelPreferences): if model_distance < distance: distance = model_distance preffered_model = model - + if preffered_model is None: preffered_model = config.sampling.models[0] - - return preffered_model \ No newline at end of file + + return preffered_model diff --git a/mcp_bridge/sampling/sampler.py b/mcp_bridge/sampling/sampler.py index 96745bd..439f3a6 100644 --- a/mcp_bridge/sampling/sampler.py +++ b/mcp_bridge/sampling/sampler.py @@ -1,38 +1,45 @@ +import json from loguru import logger from mcp import SamplingMessage import mcp.types as types -from lmos_openai_types import CreateChatCompletionResponse from mcp.types import CreateMessageRequestParams, CreateMessageResult from mcp_bridge.config import config -from mcp_bridge.openai_clients.genericHttpxClient import client +from mcp_bridge.http_clients import get_client from mcp_bridge.sampling.modelSelector import find_best_model +from mcp_bridge.inference_engine_mappers.chat.generic import chat_completion_generic_response + def make_message(x: SamplingMessage): if x.content.type == "text": return { "role": x.role, - "content": [{ - "type": "text", - "text": x.content.text, - }] + "content": [ + { + "type": "text", + "text": x.content.text, + } + ], } if x.content.type == "image": return { "role": x.role, - "content": [{ - "type": "image", - "image_url": x.content.data, - }] + "content": [ + { + "type": "image", + "image_url": x.content.data, + } + ], } + async def handle_sampling_message( message: CreateMessageRequestParams, ) -> CreateMessageResult: """perform sampling""" logger.debug(f"sampling message: {message.modelPreferences}") - + # select model model = config.sampling.models[0] if message.modelPreferences is not None: @@ -52,7 +59,7 @@ async def handle_sampling_message( logger.debug(request) - resp = await client.post( + resp = await get_client().post( 
"/chat/completions", json=request, timeout=config.sampling.timeout, @@ -62,7 +69,7 @@ async def handle_sampling_message( text = resp.text logger.debug(text) - response = CreateChatCompletionResponse.model_validate_json(text) + response = chat_completion_generic_response(json.loads(text)) logger.debug("sampling request received from endpoint")