diff --git a/.vscode/launch.json b/.vscode/launch.json index 977eeda..2dcb99c 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -10,6 +10,7 @@ "request": "launch", "django": true, "module": "mcp_bridge.main", + "pythonArgs": ["-Xutf8"] } ] -} \ No newline at end of file +} diff --git a/mcp_bridge/openai_clients/streamChatCompletion.py b/mcp_bridge/openai_clients/streamChatCompletion.py index 67c7f67..e902b46 100644 --- a/mcp_bridge/openai_clients/streamChatCompletion.py +++ b/mcp_bridge/openai_clients/streamChatCompletion.py @@ -7,6 +7,9 @@ CreateChatCompletionRequest, CreateChatCompletionStreamResponse, Function1, + FinishReason1, + ChatCompletionToolChoiceOption1, + ChatCompletionToolChoiceOption, ) from .utils import call_tool, chat_completion_add_tools from mcp_bridge.models import SSEData @@ -15,8 +18,12 @@ from mcp_bridge.tool_mappers import mcp2openai from loguru import logger from httpx_sse import aconnect_sse +import datetime +import os from sse_starlette.sse import EventSourceResponse, ServerSentEvent +import json +import traceback async def streaming_chat_completions(request: CreateChatCompletionRequest): @@ -33,24 +40,52 @@ async def streaming_chat_completions(request: CreateChatCompletionRequest): logger.error(e) +def validate_if_json_object_parsable(content: str): + try: + json.loads(content) + return True + except ValueError: + return False + + +def salvage_parsable_json_object(content: str): + content = content.strip() + for i in range(0, len(content)): + snippet = content[: len(content) - i] + if validate_if_json_object_parsable(snippet): + return snippet + + async def chat_completions(request: CreateChatCompletionRequest): """performs a chat completion using the inference server""" request.stream = True - request = await chat_completion_add_tools(request) + request = await chat_completion_add_tools( + request + ) # Date: 2025/01/27 ChatMCP clear tools after first tool call. fully_done = False while not fully_done: # json_data = request.model_dump_json( # exclude_defaults=True, exclude_none=True, exclude_unset=True # ) + if request.tools: + request.tool_choice = ChatCompletionToolChoiceOption( + root=ChatCompletionToolChoiceOption1.auto + ) - json_data = json.dumps(request.model_dump( - exclude_defaults=True, exclude_none=True, exclude_unset=True - )) + json_data = json.dumps( + request.model_dump( + exclude_defaults=True, + exclude_none=True, + exclude_unset=True, + ), + indent=4, + ensure_ascii=False, + ) - # logger.debug(json_data) + logger.debug("Request JSON:\n%s" % json_data) # empty? last: Optional[CreateChatCompletionStreamResponse] = None # last message @@ -63,19 +98,40 @@ async def chat_completions(request: CreateChatCompletionRequest): async with aconnect_sse( client, "post", "/chat/completions", content=json_data ) as event_source: - + # check if the content type is correct because the aiter_sse method # will raise an exception if the content type is not correct - if "Content-Type" in event_source.response.headers: + if "Content-Type" in event_source.response.headers: # error here. content_type = event_source.response.headers["Content-Type"] if "text/event-stream" not in content_type: logger.error(f"Unexpected Content-Type: {content_type}") error_data = await event_source.response.aread() logger.error(f"Request URL: {event_source.response.url}") - logger.error(f"Request Data: {json_data}") - logger.error(f"Response Status: {event_source.response.status_code}") - logger.error(f"Response Data: {error_data.decode(event_source.response.encoding or 'utf-8')}") - raise HTTPException(status_code=500, detail="Unexpected Content-Type") + log_dir = os.path.join(os.getcwd(), "logs") + if not os.path.exists(log_dir): + os.makedirs(log_dir) + request_data_path = f"{log_dir}/request_data_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.json" + request_data_formatted = json.dumps( + json.loads(json_data), indent=4, ensure_ascii=False + ) + with open(request_data_path, "w+") as f: + f.write(request_data_formatted) + logger.error(f"Request Data saved to: {request_data_path}") + logger.error(f"Request Data:\n{request_data_formatted}") + logger.error( + f"Response Status: {event_source.response.status_code}" + ) + error_data_decoded = error_data.decode( + event_source.response.encoding or "utf-8" + ) + error_data_path = f"{log_dir}/error_data_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.json" + logger.error(f"Response Data saved to: {error_data_path}") + logger.error(f"Response Data:\n{error_data_decoded}") + with open(error_data_path, "w+") as f: + f.write(error_data_decoded) + raise HTTPException( + status_code=500, detail="Unexpected Content-Type" + ) # iterate over the SSE stream async for sse in event_source.aiter_sse(): @@ -95,18 +151,26 @@ async def chat_completions(request: CreateChatCompletionRequest): # for some reason openrouter uses uppercase for finish_reason try: - data['choices'][0]['finish_reason'] = data['choices'][0]['finish_reason'].lower() # type: ignore + mjson_data = json.loads(data) + + # Date: 2025/01/26 failed to lowercase finish_reason: string indices must be integers, not 'str' + if mjson_data["choices"][0].keys().__contains__("finish_reason"): # type: ignore + mjson_data["choices"][0]["finish_reason"] = mjson_data["choices"][0]["finish_reason"].lower() # type: ignore + + data = json.dumps(mjson_data, ensure_ascii=False) except Exception as e: + traceback.print_exc() logger.debug(f"failed to lowercase finish_reason: {e}") try: - parsed_data = CreateChatCompletionStreamResponse.model_validate_json( - data + parsed_data = ( + CreateChatCompletionStreamResponse.model_validate_json(data) ) except Exception as e: logger.debug(data) raise e + # add the delta to the response content content = parsed_data.choices[0].delta.content content = content if content is not None else "" @@ -139,7 +203,9 @@ async def chat_completions(request: CreateChatCompletionRequest): tool_call_id = id if tool_call_id == "" else tool_call_id arg = parsed_data.choices[0].delta.tool_calls[0].function.arguments + tool_call_json += arg if arg is not None else "" + # Date: 2025/01/26 validate the tool call json. # forward SSE messages to the client logger.debug(f"{should_forward=}") @@ -151,6 +217,27 @@ async def chat_completions(request: CreateChatCompletionRequest): # save the last message last = parsed_data + if tool_call_json: + if tool_call_json.strip().startswith("{"): + if validate_if_json_object_parsable(tool_call_json): + logger.debug( + f"tool call json '{tool_call_json}' is parsable now." + ) + logger.debug("exiting message receive loop") + last.choices[0].finish_reason = FinishReason1.tool_calls + break + salvaged_json_object = salvage_parsable_json_object( + tool_call_json + ) + if salvaged_json_object: + tool_call_json = salvaged_json_object + logger.debug( + f"tool call json '{tool_call_json}' is salvagable now." + ) + logger.debug("salvaged json content:", tool_call_json) + logger.debug("exiting message receive loop") + last.choices[0].finish_reason = FinishReason1.tool_calls + break # ideally we should check this properly assert last is not None assert last.choices[0].finish_reason is not None @@ -165,6 +252,9 @@ async def chat_completions(request: CreateChatCompletionRequest): f"{tool_call_name=} {tool_call_json=}" ) # this should not be error but its easier to debug + logger.debug("clearing tool contexts to prevent tool call loops") + request.tools = None + # add received message to the history msg = ChatCompletionRequestMessage( role="assistant", @@ -181,6 +271,7 @@ async def chat_completions(request: CreateChatCompletionRequest): #### MOST OF THIS IS COPY PASTED FROM CHAT_COMPLETIONS # FIXME: this can probably be done in parallel using asyncio gather + # Date: 2025/01/26 decoding error? tool_call_result = await call_tool(tool_call_name, tool_call_json) if tool_call_result is None: continue @@ -207,6 +298,14 @@ async def chat_completions(request: CreateChatCompletionRequest): ) ) + # Date: 2025/01/26 crucial! we have to ensure the llm does not end up with infinite loop. + + # request.messages.append( + # ChatCompletionRequestMessage.model_validate( + # {"role": "user", "content": "Do you consider you have done enough tool calls? If not, please continue the rest of the tool calls. If yes, please respond to the user and end the conversation."} + # ) + # ) + logger.debug("sending next iteration of chat completion request") # when done, send the final event diff --git a/mcp_bridge/openai_clients/utils.py b/mcp_bridge/openai_clients/utils.py index 58c269b..57c1be7 100644 --- a/mcp_bridge/openai_clients/utils.py +++ b/mcp_bridge/openai_clients/utils.py @@ -3,6 +3,7 @@ from lmos_openai_types import CreateChatCompletionRequest import mcp.types import json +import traceback from mcp_bridge.mcp_clients.McpClientManager import ClientManager from mcp_bridge.tool_mappers import mcp2openai @@ -10,17 +11,22 @@ async def chat_completion_add_tools(request: CreateChatCompletionRequest): request.tools = [] + logger.info("adding tools to request") for _, session in ClientManager.get_clients(): # if session is None, then the client is not running if session.session is None: - logger.error(f"session is `None` for {session.name}") + logger.error(f"session is `None` for {session.name}") # Date:2025/01/25 why not running? continue - + logger.debug(f"session ready for {session.name}") tools = await session.session.list_tools() for tool in tools.tools: request.tools.append(mcp2openai(tool)) - + + if request.tools == []: + logger.info("this request loads no tools") + # raise Exception("no tools found. unable to initiate chat completion.") + request.tools = None return request @@ -42,9 +48,10 @@ async def call_tool( return None try: - tool_call_args = json.loads(tool_call_json) + tool_call_args = json.loads(tool_call_json) # Date: 2025/01/26 cannot load this tool call json? except json.JSONDecodeError: logger.error(f"failed to decode json for {tool_call_name}") + traceback.print_exc() return None return await session.call_tool(tool_call_name, tool_call_args, timeout)