diff --git a/README.md b/README.md
index e09b8506..39974b01 100644
--- a/README.md
+++ b/README.md
@@ -5,14 +5,15 @@ This repository contains a protocol-level CLI designed to interact with a Model
 - Protocol-level communication with the MCP Server.
 - Dynamic tool and resource exploration.
 - Support for multiple providers and models:
-  - Providers: OpenAI, Ollama.
-  - Default models: `gpt-4o-mini` for OpenAI, `qwen2.5-coder` for Ollama.
+  - Providers: OpenAI, Ollama, Amazon Bedrock.
+  - Default models: `gpt-4o-mini` for OpenAI, `qwen2.5-coder` for Ollama, `claude-3.5-sonnet` for Amazon Bedrock.

 ## Prerequisites
 - Python 3.8 or higher.
 - Required dependencies (see [Installation](#installation))
 - If using ollama you should have ollama installed and running.
 - If using openai you should have an api key set in your environment variables (OPENAI_API_KEY=yourkey)
+- If using Amazon Bedrock you should have AWS credentials configured (e.g. AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY in your environment variables).

 ## Installation
 1. Clone the repository:
@@ -43,11 +44,17 @@ uv run mcp-cli --server sqlite
 ### Command-line Arguments
 - `--server`: Specifies the server configuration to use. Required.
+
 - `--config-file`: (Optional) Path to the JSON configuration file. Defaults to `server_config.json`.
+
-- `--provider`: (Optional) Specifies the provider to use (`openai` or `ollama`). Defaults to `openai`.
+- `--provider`: (Optional) Specifies the provider to use (`openai`, `ollama`, or `amazon`). Defaults to `openai`.
+
 - `--model`: (Optional) Specifies the model to use. Defaults depend on the provider:
   - `gpt-4o-mini` for OpenAI.
   - `llama3.2` for Ollama.
+  - `claude-3.5-sonnet` (default), `claude-3.5-haiku`, `nova-lite`, or `nova-pro` for Amazon Bedrock.
+
+- `--aws-region`: (Optional) Specifies the AWS region to use with Amazon Bedrock. Defaults to `us-east-1`.

 ### Examples
 Run the client with the default OpenAI provider and model:
@@ -62,7 +69,14 @@ Run the client with a specific configuration and Ollama provider:
 uv run mcp-cli --server sqlite --provider ollama --model llama3.2
 ```

+Run the client with the Amazon Bedrock provider:
+
+```bash
+uv run mcp-cli --server sqlite --provider amazon --aws-region us-west-2
+```
+
 ## Interactive Mode
+
 The client supports interactive mode, allowing you to execute commands dynamically. Type `help` for a list of available commands or `quit` to exit the program.
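The Bedrock provider relies on the standard AWS credential chain rather than a provider-specific API key. Before pointing `mcp-cli --provider amazon` at a server, a quick standalone check like the sketch below (not part of this change) can confirm that credentials and region resolve; the model ID is the one this PR maps `claude-3.5-sonnet` to in `llm_client.py`, and the prompt text is purely illustrative.

```python
# Standalone sketch: verify Bedrock access with boto3 before running the CLI.
# Assumes AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY are set, as the
# prerequisites above require.
import os

import boto3

region = os.getenv("AWS_REGION", "us-east-1")
bedrock = boto3.client("bedrock-runtime", region_name=region)

response = bedrock.converse(
    modelId="anthropic.claude-3-5-sonnet-20241022-v2:0",
    messages=[{"role": "user", "content": [{"text": "Reply with the word 'ready'."}]}],
)
print(response["output"]["message"]["content"][0]["text"])
```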
## Supported Commands
diff --git a/server_config.json b/server_config.json
index eb6bb4ab..4f2d6e94 100644
--- a/server_config.json
+++ b/server_config.json
@@ -1,8 +1,13 @@
 {
   "mcpServers": {
     "sqlite": {
+      "transport": "stdio",
       "command": "uvx",
       "args": ["mcp-server-sqlite", "--db-path", "test.db"]
+    },
+    "fetch": {
+      "transport": "sse",
+      "endpoint": "http://localhost:3001/sse"
     }
   }
 }
diff --git a/src/mcpcli/__main__.py b/src/mcpcli/__main__.py
index ff6216ce..26e6d60b 100644
--- a/src/mcpcli/__main__.py
+++ b/src/mcpcli/__main__.py
@@ -24,7 +24,10 @@
 from mcpcli.messages.send_initialize_message import send_initialize
 from mcpcli.messages.send_call_tool import send_call_tool
 from mcpcli.messages.send_tools_list import send_tools_list
+from mcpcli.transport.sse.sse_client import sse_client
+from mcpcli.transport.sse.sse_server_parameters import SSEServerParameters
 from mcpcli.transport.stdio.stdio_client import stdio_client
+from mcpcli.transport.stdio.stdio_server_parameters import StdioServerParameters

 # Default path for the configuration file
 DEFAULT_CONFIG_FILE = "server_config.json"
@@ -286,19 +289,29 @@ async def run(config_path: str, server_names: List[str], command: str = None) ->
     # Load server configurations and establish connections for all servers
     server_streams = []
     context_managers = []
+    client = None
     for server_name in server_names:
         server_params = await load_config(config_path, server_name)

-        # Establish stdio communication for each server
-        cm = stdio_client(server_params)
-        (read_stream, write_stream) = await cm.__aenter__()
-        context_managers.append(cm)
-        server_streams.append((read_stream, write_stream))
+        # Establish stdio or sse communication for each server
+        if isinstance(server_params, StdioServerParameters):
+            cm = stdio_client(server_params)
+            (read_stream, write_stream) = await cm.__aenter__()
+            context_managers.append(cm)
+            server_streams.append((read_stream, write_stream))

-        init_result = await send_initialize(read_stream, write_stream)
-        if not init_result:
-            print(f"[red]Server initialization failed for {server_name}[/red]")
-            return
+            init_result = await send_initialize(read_stream, write_stream)
+            if not init_result:
+                print(f"[red]Server initialization failed for {server_name}[/red]")
+                return
+        elif isinstance(server_params, SSEServerParameters):
+            client = sse_client(server_params.endpoint)
+            (read_stream, write_stream) = await client.__aenter__()
+            context_managers.append(client)
+            server_streams.append((read_stream, write_stream))
+        else:
+            raise ValueError("Server transport not supported")

     try:
         if command:
@@ -340,23 +353,32 @@ def cli_main():
     parser.add_argument(
         "--provider",
-        choices=["openai", "ollama"],
+        choices=["openai", "ollama", "amazon"],
         default="openai",
         help="LLM provider to use. Defaults to 'openai'.",
     )

     parser.add_argument(
         "--model",
-        help=("Model to use. Defaults to 'gpt-4o-mini' for 'openai' and 'qwen2.5-coder' for 'ollama'."),
+        help=("Model to use. Defaults to 'gpt-4o-mini' for 'openai', 'qwen2.5-coder' for 'ollama', and 'claude-3.5-sonnet' for 'amazon'."),
+    )
+
+    parser.add_argument(
+        "--aws-region",
+        default="us-east-1",
+        help=("AWS region to use.
Defaults to 'us-east-1'."), ) args = parser.parse_args() model = args.model or ( - "gpt-4o-mini" if args.provider == "openai" else "qwen2.5-coder" + "gpt-4o-mini" if args.provider == "openai" + else "claude-3.5-sonnet" if args.provider == "amazon" + else "qwen2.5-coder" ) os.environ["LLM_PROVIDER"] = args.provider os.environ["LLM_MODEL"] = model + os.environ["AWS_REGION"] = args.aws_region try: result = anyio.run(run, args.config_file, args.servers, args.command) diff --git a/src/mcpcli/chat_handler.py b/src/mcpcli/chat_handler.py index 3f45d335..9b5c9cfd 100644 --- a/src/mcpcli/chat_handler.py +++ b/src/mcpcli/chat_handler.py @@ -1,5 +1,6 @@ # chat_handler.py import json +from datetime import datetime from rich import print from rich.markdown import Markdown @@ -66,6 +67,74 @@ async def process_conversation( response_content = completion.get("response", "No response") tool_calls = completion.get("tool_calls", []) + # Save assistant response with additional metadata if provider is amazon + if client.provider == "amazon": + content = [] + if response_content: + content.append({"text": response_content}) + + if tool_calls: + for tool_call in tool_calls: + if hasattr(tool_call, "function"): + tool_use = { + "toolUse": { + "toolUseId": tool_call.id if hasattr(tool_call, 'id') else None, + "name": tool_call.function.name.replace("-", "_"), + "input": json.loads(tool_call.function.arguments) + } + } + content.append(tool_use) + elif isinstance(tool_call, dict) and "function" in tool_call: + tool_use = { + "toolUse": { + "toolUseId": tool_call.get("id"), + "name": tool_call["function"]["name"].replace("-", "_"), + "input": json.loads(tool_call["function"]["arguments"]) if isinstance(tool_call["function"]["arguments"], str) else tool_call["function"]["arguments"] + } + } + content.append(tool_use) + + assistant_message = { + "role": "assistant", + "content": content, + "metadata": { + "timestamp": datetime.now().isoformat(), + "conversation_id": completion.get("conversation_id", ""), + } + } + conversation_history.append(assistant_message) + else: + content = [] + if response_content: + content.append({"text": response_content}) + + if tool_calls: + for tool_call in tool_calls: + if hasattr(tool_call, "function"): + tool_use = { + "toolUse": { + "toolUseId": f"tooluse_{datetime.now().strftime('%Y%m%d%H%M%S')}", + "name": tool_call.function.name.replace("-", "_"), + "input": json.loads(tool_call.function.arguments) + } + } + content.append(tool_use) + elif isinstance(tool_call, dict) and "function" in tool_call: + tool_use = { + "toolUse": { + "toolUseId": f"tooluse_{datetime.now().strftime('%Y%m%d%H%M%S')}", + "name": tool_call["function"]["name"].replace("-", "_"), + "input": json.loads(tool_call["function"]["arguments"]) if isinstance(tool_call["function"]["arguments"], str) else tool_call["function"]["arguments"] + } + } + content.append(tool_use) + + assistant_message = { + "role": "assistant", + "content": content + } + #conversation_history.append(assistant_message) + if tool_calls: for tool_call in tool_calls: # Extract tool_name and raw_arguments as before @@ -85,10 +154,8 @@ async def process_conversation( try: raw_arguments = json.loads(raw_arguments) except json.JSONDecodeError: - # If it's not valid JSON, just display as is pass - # Now raw_arguments should be a dict or something we can pretty-print as JSON tool_args_str = json.dumps(raw_arguments, indent=2) tool_md = f"**Tool Call:** {tool_name}\n\n```json\n{tool_args_str}\n```" @@ -106,7 +173,6 @@ async def process_conversation( 
 print(
     Panel(Markdown(assistant_panel_text), style="bold blue", title="Assistant")
 )
-            conversation_history.append({"role": "assistant", "content": response_content})
             break
diff --git a/src/mcpcli/config.py b/src/mcpcli/config.py
index 192d6dc8..864f2814 100644
--- a/src/mcpcli/config.py
+++ b/src/mcpcli/config.py
@@ -2,10 +2,11 @@
 import json
 import logging

+from mcpcli.transport.sse.sse_server_parameters import SSEServerParameters
 from mcpcli.transport.stdio.stdio_server_parameters import StdioServerParameters


-async def load_config(config_path: str, server_name: str) -> StdioServerParameters:
+async def load_config(config_path: str, server_name: str) -> "StdioServerParameters | SSEServerParameters":
     """Load the server configuration from a JSON file."""
     try:
         # debug
@@ -17,23 +18,44 @@ async def load_config(config_path: str, server_name: str) -> StdioServerParamete
         # Retrieve the server configuration
         server_config = config.get("mcpServers", {}).get(server_name)

         if not server_config:
             error_msg = f"Server '{server_name}' not found in configuration file."
             logging.error(error_msg)
             raise ValueError(error_msg)

+        # Infer the transport when it is not declared explicitly:
+        # a "command" key implies stdio, an "endpoint" key implies sse.
+        if "transport" not in server_config:
+            if "command" in server_config:
+                server_config["transport"] = "stdio"
+            elif "endpoint" in server_config:
+                server_config["transport"] = "sse"
+            else:
+                error_msg = f"Transport for server '{server_name}' could not be determined from the configuration file."
+                logging.error(error_msg)
+                raise ValueError(error_msg)

         # Construct the server parameters
-        result = StdioServerParameters(
-            command=server_config["command"],
-            args=server_config.get("args", []),
-            env=server_config.get("env"),
-        )
-
-        # debug
-        logging.debug(
-            f"Loaded config: command='{result.command}', args={result.args}, env={result.env}"
-        )
+        if server_config["transport"] == "stdio":
+            result = StdioServerParameters(
+                command=server_config["command"],
+                args=server_config.get("args", []),
+                env=server_config.get("env"),
+            )
+            # debug
+            logging.debug(
+                f"Loaded config: command='{result.command}', args={result.args}, env={result.env}"
+            )
+        elif server_config["transport"] == "sse":
+            result = SSEServerParameters(
+                endpoint=server_config["endpoint"],
+            )
+        else:
+            error_msg = f"Server transport '{server_config['transport']}' not supported."
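+            # Surface the unsupported transport value in both the log and the raised error.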
+ logging.error(error_msg) + raise ValueError(error_msg) + # return result return result diff --git a/src/mcpcli/llm_client.py b/src/mcpcli/llm_client.py index 840da3ea..2b15e4f7 100644 --- a/src/mcpcli/llm_client.py +++ b/src/mcpcli/llm_client.py @@ -1,15 +1,23 @@ import logging import os import uuid +import json from typing import Any, Dict, List import ollama +import boto3 from dotenv import load_dotenv from openai import OpenAI # Load environment variables load_dotenv() +BEDROCK_MODEL_IDS = { + "claude-3.5-haiku": "anthropic.claude-3-5-haiku-20241022-v1:0", + "claude-3.5-sonnet": "anthropic.claude-3-5-sonnet-20241022-v2:0", + "nova-lite":"amazon.nova-lite-v1:0", + "nova-pro":"amazon.nova-pro-v1:0", + } class LLMClient: def __init__(self, provider="openai", model="gpt-4o-mini", api_key=None): @@ -17,6 +25,7 @@ def __init__(self, provider="openai", model="gpt-4o-mini", api_key=None): self.provider = provider self.model = model self.api_key = api_key or os.getenv("OPENAI_API_KEY") + self.base_url = None or os.getenv("OPENAI_BASE_URL") # ensure we have the api key for openai if set if provider == "openai" and not self.api_key: @@ -25,7 +34,11 @@ def __init__(self, provider="openai", model="gpt-4o-mini", api_key=None): # check ollama is good if provider == "ollama" and not hasattr(ollama, "chat"): raise ValueError("Ollama is not properly configured in this environment.") - + + # check amazon is good + if provider == "amazon" and not hasattr(boto3, "client"): + raise ValueError("Amazon is not properly configured in this environment.") + def create_completion( self, messages: List[Dict], tools: List = None ) -> Dict[str, Any]: @@ -36,6 +49,9 @@ def create_completion( elif self.provider == "ollama": # perform an ollama completion return self._ollama_completion(messages, tools) + elif self.provider == "amazon": + # perform an amazon completion' + return self._amazon_completion(messages, tools) else: # unsupported providers raise ValueError(f"Unsupported provider: {self.provider}") @@ -44,7 +60,7 @@ def _openai_completion(self, messages: List[Dict], tools: List) -> Dict[str, Any """Handle OpenAI chat completions.""" # get the openai client client = OpenAI(api_key=self.api_key) - + try: # make a request, passing in tools response = client.chat.completions.create( @@ -108,3 +124,105 @@ def _ollama_completion(self, messages: List[Dict], tools: List) -> Dict[str, Any # error logging.error(f"Ollama API Error: {str(e)}") raise ValueError(f"Ollama API Error: {str(e)}") + + def _amazon_completion(self, messages: List[Dict], tools: List) -> Dict[str, Any]: + """Handle Amazon chat completions.""" + client = boto3.client('bedrock-runtime', region_name=os.getenv("AWS_REGION", "us-east-1")) + model_id = BEDROCK_MODEL_IDS.get(self.model, "anthropic.claude-3-5-sonnet-20241022-v2:0") + try: + # Separate system messages from other messages + system_prompts = [] + conversation_messages = [] + + + for msg in messages: + if msg["role"] == "system": + system_prompts.append({"text": msg["content"]}) + else: + # Handle tool results differently + + if isinstance(msg["content"], list) and msg["content"] and "toolResult" in msg["content"][0]: + + if 'toolResult' in msg["content"][0]: + if 'tool_call_result' not in msg['content'][0]['toolResult']['content'][0]['json']: + msg['content'][0]['toolResult']['content'][0]['json']={"tool_call_result":msg['content'][0]['toolResult']['content'][0]['json']} + conversation_messages.append({ + "role": msg["role"], + "content": msg["content"] # Keep the original toolResult structure + 
}) + else: + conversation_messages.append({ + "role": msg["role"], + "content": [{"text": msg["content"]}] + }) + + # Use the list of system prompt dictionaries directly + # Handle nested text content for assistant messages + for msg in conversation_messages: + if msg["role"] == "assistant" and isinstance(msg["content"], list) and len(msg["content"]) == 1 and isinstance(msg["content"][0]["text"], list): + msg["content"] = msg["content"][0]["text"] + logging.debug("conversation_messages", conversation_messages) + # Convert OpenAI format tools to Amazon Bedrock format + tool_config = { + "tools": [] + } + + if tools: + for tool in tools: + if tool["type"] == "function": + func = tool["function"] + bedrock_tool = { + "toolSpec": { + "name": func["name"].replace("-", "_"), + "description": func.get("description", func["name"]), + "inputSchema": { + "json": func["parameters"] + } + } + } + tool_config["tools"].append(bedrock_tool) + + # Make API call with tools + response = client.converse( + modelId=model_id, + messages=conversation_messages, + toolConfig=tool_config if tools else None, + system=system_prompts + ) + + logging.info(f"Amazon raw response: {response}") + + # Extract the message and tool calls + output_message = response.get('output', {}).get('message', {}) + content_list = output_message.get('content', []) + + # Extract text response and tool calls + response_text = "" + tool_calls = [] + + + + + for content in content_list: + if 'text' in content: + response_text += content['text'] + elif 'toolUse' in content: + tool_use = content['toolUse'] + tool_calls.append({ + 'id': tool_use.get('toolUseId', str(uuid.uuid4())), + 'type': 'function', + 'function': { + 'name': tool_use['name'].replace("_", "-"), + 'arguments': json.dumps(tool_use.get('input', {})) + } + }) + return { + "response": response_text or "No response", + "tool_calls": tool_calls + } + + except Exception as e: + import traceback + traceback.print_exc() + logging.error(f"Amazon API Error: {str(e)}") + raise ValueError(f"Amazon API Error: {str(e)}") diff --git a/src/mcpcli/tools_handler.py b/src/mcpcli/tools_handler.py index 7c717575..111b99ed 100644 --- a/src/mcpcli/tools_handler.py +++ b/src/mcpcli/tools_handler.py @@ -1,5 +1,6 @@ import json import logging +import os import re from typing import Any, Dict, Optional from mcpcli.messages.send_call_tool import send_call_tool @@ -32,6 +33,7 @@ async def handle_tool_call(tool_call, conversation_history, server_streams): """ tool_name = "unknown_tool" raw_arguments = {} + tool_call_id = None try: # Handle object-style tool calls from both OpenAI and Ollama @@ -79,37 +81,50 @@ async def handle_tool_call(tool_call, conversation_history, server_streams): # Format the tool response formatted_response = format_tool_response(tool_response.get("content", [])) logging.debug(f"Tool '{tool_name}' Response: {formatted_response}") - + # Update the conversation history with the tool call # Add the tool call itself (for OpenAI tracking) - conversation_history.append( - { + + if os.environ.get("LLM_PROVIDER") == "amazon": + # Format response for Bedrock + tool_call_id = tool_call.get("id", "unknown tool id") + tool_result = { + "toolUseId": tool_call_id, + #"content": [{"json": tool_response.get("content", [])}] + "content": [{"json": formatted_response[0].get("json", [])}] + } + + + conversation_history.append({ + "role": "user", + "content": [{ + "toolResult": tool_result + }] + }) + else: + # Standard OpenAI format + conversation_history.append({ "role": "assistant", "content": None, 
- "tool_calls": [ - { - "id": f"call_{tool_name}", - "type": "function", - "function": { - "name": tool_name, - "arguments": json.dumps(tool_args) - if isinstance(tool_args, dict) - else tool_args, - }, - } - ], - } - ) - - # Add the tool response to conversation history - conversation_history.append( - { + "tool_calls": [{ + "id": f"call_{tool_name}", + "type": "function", + "function": { + "name": tool_name, + "arguments": json.dumps(tool_args) if isinstance(tool_args, dict) else tool_args, + }, + }] + }) + + conversation_history.append({ "role": "tool", "name": tool_name, "content": formatted_response, "tool_call_id": f"call_{tool_name}", - } - ) + }) + + + logging.debug("conversation_history",conversation_history) except json.JSONDecodeError: logging.debug( @@ -121,13 +136,30 @@ async def handle_tool_call(tool_call, conversation_history, server_streams): def format_tool_response(response_content): """Format the response content from a tool.""" - if isinstance(response_content, list): - return "\n".join( - item.get("text", "No content") - for item in response_content - if item.get("type") == "text" - ) - return str(response_content) + if os.environ.get("LLM_PROVIDER") == "amazon": + if isinstance(response_content, list): + for item in response_content: + if item.get("type") == "text" and isinstance(item.get("text"), str): + try: + # Replace single quotes with double quotes and 'None' with 'null' + text_content = item["text"].replace("'", '"').replace("None", "null") + parsed_json = json.loads(text_content) + # Convert to format expected by Bedrock + return [{"type": "json", "json": parsed_json}] + except json.JSONDecodeError as e: + logging.error(f"JSON parse error: {e}") + return response_content + + return response_content + else: + # Original formatting for other providers + if isinstance(response_content, list): + return "\n".join( + item.get("text", "No content") + for item in response_content + if item.get("type") == "text" + ) + return str(response_content) async def fetch_tools(read_stream, write_stream): diff --git a/src/mcpcli/transport/sse/__init__.py b/src/mcpcli/transport/sse/__init__.py new file mode 100644 index 00000000..f1bf05b8 --- /dev/null +++ b/src/mcpcli/transport/sse/__init__.py @@ -0,0 +1,5 @@ +"""SSE (Server-Sent Events) transport implementation.""" + +from .sse_server_parameters import SSEServerParameters +from .sse_client import sse_client +__all__ = ["SSEClient", "SSEServerParameters", "SSEServerShutdown"] \ No newline at end of file diff --git a/src/mcpcli/transport/sse/sse_client.py b/src/mcpcli/transport/sse/sse_client.py new file mode 100644 index 00000000..ade1d9ca --- /dev/null +++ b/src/mcpcli/transport/sse/sse_client.py @@ -0,0 +1,101 @@ +import json +import logging +import asyncio +import traceback +import httpx +from contextlib import asynccontextmanager +from urllib.parse import urlparse +import anyio +import sys + +from typing import Optional, Dict, Any +from mcpcli.messages.message_types.json_rpc_message import JSONRPCMessage + +@asynccontextmanager +async def sse_client(url: str): + """建立SSE连接并返回读写流""" + parsed_url = urlparse(url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + + # Create memory object streams for reading and writing + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + + # Create Asynchttpx client, set timeout and keep connection + limits = httpx.Limits(max_keepalive_connections=5, max_connections=10) + async with 
httpx.AsyncClient(timeout=None, limits=limits) as client: + # Add shared variable to store message endpoint + message_endpoint = None + + async def sse_reader(): + nonlocal message_endpoint + while True: # Add reconnection logic + try: + logging.debug(f"SSE transport endpoint: {url} connection init") + async with client.stream('GET', f"{url}", timeout=None) as response: + first_lines = [] + async for line in response.aiter_lines(): + # Collect first two lines of data + if len(first_lines) < 2: + first_lines.append(line.strip()) + if len(first_lines) == 2: + # Extract message endpoint from first line + + endpoint_line = first_lines[1] + if endpoint_line.startswith('data:'): + endpoint_data = endpoint_line.replace('data: ', '').strip() + message_endpoint = endpoint_data + logging.debug("Extracted message endpoint:", message_endpoint) + continue + + continue + + if line.strip().startswith('data: '): + data = line.replace('data: ', '').strip() + try: + json_data = json.loads(data) + message = JSONRPCMessage.model_validate(json_data) + await read_stream_writer.send(message) + except json.JSONDecodeError as e: + logging.error(f"JSON decode error: {e}") + except Exception as e: + logging.error(f"Error processing message: {e}") + except httpx.RequestError as e: + logging.error(f"SSE connection error: {e}") + await anyio.sleep(1) # Wait before reconnecting + except Exception as e: + logging.error(f"Unexpected error: {e}") + await anyio.sleep(1) # Wait before reconnecting + + async def message_sender(): + try: + async with write_stream_reader: + async for message in write_stream_reader: + json_data = message.model_dump_json(exclude_none=True) + try: + # Use stored message_endpoint + if message_endpoint: + endpoint = f"{base_url}{message_endpoint}" + logging.debug("Sending message to endpoint:", endpoint,json_data) + response = await client.post( + endpoint, + json=json.loads(json_data), + headers={"Content-Type": "application/json"} + ) + logging.debug(f"Message sent successfully: {response.status_code}") + logging.debug(f"Response: {response.text}") + response.raise_for_status() + except Exception as e: + traceback.print_exc() + logging.error(f"Error sending message: {e}") + except Exception as e: + logging.error(f"Message sender error: {e}") + + try: + async with anyio.create_task_group() as tg: + tg.start_soon(sse_reader) + tg.start_soon(message_sender) + yield read_stream, write_stream + finally: + await read_stream.aclose() + await write_stream.aclose() \ No newline at end of file diff --git a/src/mcpcli/transport/sse/sse_server_parameters.py b/src/mcpcli/transport/sse/sse_server_parameters.py new file mode 100644 index 00000000..b730911e --- /dev/null +++ b/src/mcpcli/transport/sse/sse_server_parameters.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass +from typing import Optional + +@dataclass +class SSEServerParameters: + """SSE Server Parameters""" + endpoint: str = "http://localhost:8000/sse" + + + @property + def url(self) -> str: + """Return server URL""" + return self.endpoint \ No newline at end of file
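For trying the new transport end to end, here is a minimal usage sketch (not part of the change set). It assumes an MCP server is listening on the sample endpoint from `server_config.json`, and that the same `send_initialize` handshake used on the stdio path also applies over SSE; note that the `run()` path above currently skips initialization for SSE servers.

```python
# Usage sketch: connect over SSE and run the MCP initialize handshake.
# The endpoint value is the sample one from server_config.json in this PR.
import anyio

from mcpcli.messages.send_initialize_message import send_initialize
from mcpcli.transport.sse.sse_client import sse_client
from mcpcli.transport.sse.sse_server_parameters import SSEServerParameters


async def main() -> None:
    params = SSEServerParameters(endpoint="http://localhost:3001/sse")
    async with sse_client(params.endpoint) as (read_stream, write_stream):
        init_result = await send_initialize(read_stream, write_stream)
        print("initialized:", bool(init_result))


anyio.run(main)
```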