Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"request": "launch",
"django": true,
"module": "mcp_bridge.main",
"pythonArgs": ["-Xutf8"]
}
]
}
}
15 changes: 13 additions & 2 deletions mcp_bridge/openai_clients/chatCompletion.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
FinishReason1,
)

from .utils import call_tool, chat_completion_add_tools
from .utils import call_tool, chat_completion_add_tools, validate_if_json_object_parsable, json_pretty_print
from mcp_bridge.http_clients import get_client
from mcp_bridge.inference_engine_mappers.chat.requester import chat_completion_requester
from mcp_bridge.inference_engine_mappers.chat.responder import chat_completion_responder
Expand Down Expand Up @@ -86,11 +86,22 @@ async def chat_completions(
return response

logger.debug("tool calls found")

logger.debug("clearing tool contexts to prevent tool call loops")
request.tools = None

for tool_call in response.choices[0].message.tool_calls.root:
logger.debug(
f"tool call: {tool_call.function.name} arguments: {json.loads(tool_call.function.arguments)}"
f"tool call: {tool_call.function.name}"
)

if validate_if_json_object_parsable(tool):
logger.debug(f"arguments:\n{json_pretty_print(tool_call.function.arguments)}")
else:
logger.debug("non-json arguments given: %s" % tool_call.function.arguments)
logger.debug("unable to parse tool call argument as json. skipping...")
continue

# FIXME: this can probably be done in parallel using asyncio gather
tool_call_result = await call_tool(
tool_call.function.name, tool_call.function.arguments
Expand Down
42 changes: 39 additions & 3 deletions mcp_bridge/openai_clients/streamChatCompletion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import datetime
import json
import os
import time
import traceback
from typing import Optional
from secrets import token_hex
from lmos_openai_types import (
Expand All @@ -8,13 +11,20 @@
CreateChatCompletionRequest,
CreateChatCompletionStreamResponse,
Function1,
FinishReason1,
)

from mcp_bridge.inference_engine_mappers.chat.requester import chat_completion_requester
from mcp_bridge.inference_engine_mappers.chat.stream_responder import (
chat_completion_stream_responder,
)
from .utils import call_tool, chat_completion_add_tools
from .utils import (
call_tool,
chat_completion_add_tools,
json_pretty_print,
salvage_parsable_json_object,
validate_if_json_object_parsable,
)
from mcp_bridge.models import SSEData, upstream_error
from mcp_bridge.http_clients import get_client
from loguru import logger
Expand Down Expand Up @@ -68,9 +78,9 @@ async def chat_completions(request: CreateChatCompletionRequest):
# exclude_defaults=True, exclude_none=True, exclude_unset=True
# )

json_data = json.dumps(chat_completion_requester(request))
json_data = json_pretty_print(chat_completion_requester(request))

# logger.debug(json_data)
logger.debug("Request JSON:\n%s" % json_data)

last: Optional[CreateChatCompletionStreamResponse] = None # last message

Expand Down Expand Up @@ -211,6 +221,29 @@ async def chat_completions(request: CreateChatCompletionRequest):
# save the last message
last = parsed_data

# perform early stopping on parsable tool_call_json
if tool_call_json:
if tool_call_json.strip().startswith("{"):
if validate_if_json_object_parsable(tool_call_json):
logger.debug(
f"tool call json '{tool_call_json}' is parsable now."
)
logger.debug("exiting message receive loop")
last.choices[0].finish_reason = FinishReason1.tool_calls
break
salvaged_json_object = salvage_parsable_json_object(
tool_call_json
)
if salvaged_json_object:
tool_call_json = salvaged_json_object
logger.debug(
f"tool call json '{tool_call_json}' is salvagable now."
)
logger.debug("salvaged json content:", tool_call_json)
logger.debug("exiting message receive loop")
last.choices[0].finish_reason = FinishReason1.tool_calls
break

# ideally we should check this properly
assert last is not None

Expand All @@ -229,6 +262,9 @@ async def chat_completions(request: CreateChatCompletionRequest):
f"{tool_call_name=} {tool_call_json=}"
) # this should not be error but its easier to debug

logger.debug("clearing tool contexts to prevent tool call loops")
request.tools = None

# add received message to the history
msg = ChatCompletionRequestMessage(
role="assistant",
Expand Down
Empty file.
38 changes: 34 additions & 4 deletions mcp_bridge/openai_clients/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,53 @@
from lmos_openai_types import CreateChatCompletionRequest
import mcp.types
import json
import traceback

from mcp_bridge.mcp_clients.McpClientManager import ClientManager
from mcp_bridge.tool_mappers import mcp2openai


def json_pretty_print(obj) -> str:
    """Render *obj* as indented, human-readable JSON text.

    Args:
        obj: Either JSON text (``str``, ``bytes`` or ``bytearray``), which is
            parsed first, or any value that ``json.dumps`` can serialize.

    Returns:
        The JSON document indented by four spaces, with non-ASCII characters
        written as-is instead of being ``\\u``-escaped.

    Raises:
        ValueError: If textual input is not valid JSON, or binary input
            cannot be decoded as UTF-8.
        TypeError: If a non-text value is not JSON serializable.
    """
    # isinstance() instead of an exact type comparison, so subclasses of
    # str/bytes (and bytearray) are parsed as JSON text rather than being
    # dumped as a quoted string or rejected by json.dumps.
    if isinstance(obj, (bytes, bytearray)):
        obj = obj.decode()
    if isinstance(obj, str):
        obj = json.loads(obj)
    return json.dumps(obj, indent=4, ensure_ascii=False)

def validate_if_json_object_parsable(content: str) -> bool:
    """Report whether *content* can be parsed by ``json.loads``.

    Despite the name, any JSON value (object, array, string, number, ...)
    counts as parsable; the check is not limited to JSON objects.

    Args:
        content: The candidate JSON text.

    Returns:
        ``True`` if ``json.loads(content)`` succeeds, ``False`` otherwise,
        including when *content* is not text at all (e.g. ``None``).
    """
    try:
        json.loads(content)
    except (ValueError, TypeError):
        # ValueError covers malformed JSON (JSONDecodeError) and undecodable
        # bytes; TypeError is what json.loads raises for non-text input.
        return False
    return True


def salvage_parsable_json_object(content: str):
    """Return the longest leading part of *content* that parses as JSON.

    Surrounding whitespace is stripped first. The text is then decoded once
    from the start; anything following the first complete JSON value (for
    example a second, concatenated object) is dropped.

    Args:
        content: Text that may hold a JSON value followed by extra data.

    Returns:
        The parsable prefix as a string, or ``None`` when *content* does not
        start with a complete JSON value.
    """
    content = content.strip()
    try:
        # raw_decode parses one value from the start and reports where it
        # ends, replacing the retry-on-every-shorter-prefix loop with a
        # single pass.
        _, end = json.JSONDecoder().raw_decode(content)
    except ValueError:
        return None

    # json.loads tolerates JSON whitespace after a value, so it is part of
    # the longest parsable prefix; keep it for identical results.
    while end < len(content) and content[end] in " \t\n\r":
        end += 1
    return content[:end]

async def chat_completion_add_tools(request: CreateChatCompletionRequest):
    """Populate ``request.tools`` with the tools of every available client.

    Iterates over the pairs returned by ``ClientManager.get_clients()``,
    lists the tools of each client that has a session, and appends each one
    converted with ``mcp2openai``. Any tools already on the request are
    discarded first.

    Args:
        request: The chat completion request to modify in place.

    Returns:
        The same *request* object. ``request.tools`` is the collected list,
        or ``None`` when no tools were found.
    """
    request.tools = []
    logger.info("adding tools to request")

    for _, session in ClientManager.get_clients():
        # if session is None, then the client is not running
        if session.session is None:
            # TODO: find out why a registered client can end up not running
            logger.error(f"session is `None` for {session.name}")
            continue

        logger.debug(f"session ready for {session.name}")
        tools = await session.session.list_tools()
        request.tools.extend(mcp2openai(tool) for tool in tools.tools)

    if not request.tools:
        # Finding no tools is not treated as an error: the field is reset to
        # None instead of being left as an empty list.
        logger.info("this request loads no tools")
        request.tools = None
    return request


Expand All @@ -42,9 +71,10 @@ async def call_tool(
return None

try:
tool_call_args = json.loads(tool_call_json)
tool_call_args = json.loads(tool_call_json) # Date: 2025/01/26 cannot load this tool call json?
except json.JSONDecodeError:
logger.error(f"failed to decode json for {tool_call_name}")
traceback.print_exc()
return None

return await session.call_tool(tool_call_name, tool_call_args, timeout)