diff --git a/cecli/coders/agent_coder.py b/cecli/coders/agent_coder.py index aa764bff223..ef8970b19ef 100644 --- a/cecli/coders/agent_coder.py +++ b/cecli/coders/agent_coder.py @@ -13,7 +13,7 @@ from cecli import utils from cecli.change_tracker import ChangeTracker -from cecli.helpers import nested +from cecli.helpers import nested, responses from cecli.helpers.background_commands import BackgroundCommandManager from cecli.helpers.conversation import ConversationService, MessageTag from cecli.helpers.similarity import ( @@ -124,7 +124,7 @@ def _get_agent_config(self): config["command_timeout"] = nested.getter(config, "command_timeout", 30) config["hot_reload"] = nested.getter(config, "hot_reload", False) - config["tools_paths"] = nested.getter(config, "tools_paths", []) + config["tools_paths"] = nested.getter(config, ["tools_paths", "tool_paths"], []) config["tools_includelist"] = nested.getter( config, ["tools_includelist", "tools_whitelist"], [] ) @@ -246,7 +246,7 @@ async def _execute_local_tool_calls(self, tool_calls_list): tool_name = tool_call.function.name result_message = "" try: - if tool_name.lower() in self.write_tools: + if responses.unprefix_tool_name(tool_name)[1].lower() in self.write_tools: used_write_tool = True args_string = tool_call.function.arguments.strip() @@ -738,8 +738,11 @@ async def process_tool_calls(self, tool_call_response): if tool_name: self.last_round_tools.append(tool_name) + content = ( + str(self.partial_response_content) if self.partial_response_content else "" + ) tool_call_str = str(tool_call_copy) - tool_vector = create_bigram_vector((tool_call_str,)) + tool_vector = create_bigram_vector((tool_call_str, content)) tool_vector_norm = normalize_vector(tool_vector) self.tool_call_vectors.append(tool_vector_norm) if self.last_round_tools: @@ -753,6 +756,27 @@ async def process_tool_calls(self, tool_call_response): # Ensure we call base implementation to trigger execution of all tools (native + extracted) return await super().process_tool_calls(tool_call_response) + async def _execute_local_tools(self, tool_calls): + """Execute local tools via ToolRegistry.""" + return await self._execute_local_tool_calls(tool_calls) + + async def _execute_mcp_tools(self, server, tool_calls): + """Execute MCP tools via LiteLLM.""" + tool_responses = [] + for tool_call in tool_calls: + # Use existing _execute_mcp_tool logic; avoid shadowing the responses helper module + result = await self._execute_mcp_tool( + server, tool_call.function.name, json.loads(tool_call.function.arguments) + ) + tool_responses.append( + { + "role": "tool", + "tool_call_id": tool_call.id, + "content": result, + } + ) + return tool_responses + def get_active_model(self): if self.main_model.agent_model: return self.main_model.agent_model @@ -870,14 +894,15 @@ def _get_repetitive_tools(self): """ Identifies repetitive tool usage patterns from rounds of tool calls.
""" history_len = len(self.tool_usage_history) - if history_len < 5: + if history_len < 1: return set() similarity_repetitive_tools = self._get_repetitive_tools_by_similarity() if self.last_round_tools: last_round_has_write = any( - tool.lower() in self.write_tools for tool in self.last_round_tools + responses.unprefix_tool_name(tool)[1].lower() in self.write_tools + for tool in self.last_round_tools ) if last_round_has_write: # Remove half of the history when a write tool is used @@ -889,7 +914,8 @@ def _get_repetitive_tools(self): return { tool for tool in similarity_repetitive_tools - if tool.lower() in self.read_tools or tool.lower() in self.write_tools + if responses.unprefix_tool_name(tool)[1].lower() in self.read_tools + or responses.unprefix_tool_name(tool)[1].lower() in self.write_tools } def _get_repetitive_tools_by_similarity(self): @@ -990,6 +1016,8 @@ def _generate_tool_context(self, repetitive_tools): ) context_parts.append("\n\n") + repetition_warning = None + if repetitive_tools: if not self.model_kwargs: self.model_kwargs = { @@ -1019,7 +1047,7 @@ def _generate_tool_context(self, repetitive_tools): ) self.model_kwargs["frequency_penalty"] = min(0, max(freq_penalty - 0.15, 0)) - self.model_kwargs["temperature"] = min(self.model_kwargs["temperature"], 1) + self.model_kwargs["temperature"] = max(0, min(self.model_kwargs["temperature"], 1)) # One twentieth of the time, just straight reset the randomness if random.random() < 0.05: self.model_kwargs = {} @@ -1028,11 +1056,11 @@ def _generate_tool_context(self, repetitive_tools): self._last_repetitive_warning_turn = self.turn_count self._last_repetitive_warning_severity += 1 - repetition_warning = f""" -## Repetition Detected -You have been using the following tools repetitively: {', '.join([f'`{t}`' for t in repetitive_tools])}. -Do not repeat the same parameters for these tools in your next turns. Prioritize editing. - """ + repetition_warning = ( + "## Repetition Detected\nYou have used the following tools repetitively:" + f" {', '.join([f'`{t}`' for t in repetitive_tools])}.\nDo not repeat the same" + " parameters for these tools in your next turns. Prioritize editing.\n" + ) if self._last_repetitive_warning_severity > 5: self._last_repetitive_warning_severity = 0 @@ -1079,28 +1107,42 @@ def _generate_tool_context(self, repetitive_tools): ] ) - repetition_warning += f""" -## CRITICAL: Execution Loop Detected -You may be stuck in a cycle. To break the exploration loop and continue making progress, please do the following: -1. **Analyze**: Summarize your findings. Describe how you can stop repeating yourself and make progress. -2. **Reframe**: To help with creativity, include a 2-sentence story about {animal} {verb} {fruit} in your thoughts. -3. **Pivot**: Modify your current exploration strategy. Try alternative methods. Prioritize editing. - """ + repetition_warning += ( + "## CRITICAL: Execution Loop Detected\nYou may be stuck in a cycle. To break" + " the exploration loop and continue making progress, please do the" + " following:\n1. **Analyze**: Summarize your findings. Describe how you can" + " stop repeating yourself and make progress.2. **Reframe**: To help with" + f" creativity, include a 2-sentence story about {animal} {verb} {fruit} in your" + " thoughts.\n3. **Pivot**: Modify your current exploration strategy. Try" + " alternative methods. 
Prioritize editing.\n" + ) - context_parts.append(repetition_warning) + # context_parts.append(repetition_warning) else: self.model_kwargs = {} self._last_repetitive_warning_severity = min( self._last_repetitive_warning_severity - 1, 0 ) + if repetition_warning: + ConversationService.get_manager(self).add_message( + message_dict=dict(role="user", content=repetition_warning), + tag=MessageTag.CUR, + hash_key=("repetition", "agent"), + mark_for_delete=0, + promotion=ConversationService.get_manager(self).DEFAULT_TAG_PROMOTION_VALUE + 2, + mark_for_demotion=1, + force=True, + ) + context_parts.append("") return "\n".join(context_parts) def _generate_write_context(self): if self.last_round_tools: last_round_has_write = any( - tool.lower() in self.write_tools for tool in self.last_round_tools + responses.unprefix_tool_name(tool)[1].lower() in self.write_tools + for tool in self.last_round_tools ) if last_round_has_write: context_parts = [ diff --git a/cecli/coders/base_coder.py b/cecli/coders/base_coder.py index 766208ee344..5e7ae2a44e0 100755 --- a/cecli/coders/base_coder.py +++ b/cecli/coders/base_coder.py @@ -2394,6 +2394,7 @@ async def send_message(self, inp): return except Exception as e: self.io.tool_error(f"Error processing tool calls: {str(e)}") + self.io.tool_error(traceback.format_exc()) self.reflected_message = True return # Continue without tool processing @@ -2445,29 +2446,40 @@ async def send_message(self, inp): self.reflected_message = test_errors return - async def process_tool_calls(self, tool_call_response): - # Use partial_response_tool_calls if available (populated by consolidate_chunks) - # otherwise try to extract from tool_call_response - original_tool_calls = [] + def _extract_and_prepare_tool_calls(self, tool_call_response): + """ + Unified extraction and preparation of tool calls. + Returns: list of prepared tool calls + """ + # 1. Use partial_response_tool_calls if available if self.partial_response_tool_calls: - original_tool_calls = self.partial_response_tool_calls + tool_calls = self.partial_response_tool_calls + # 2. Extract from tool_call_response elif tool_call_response is not None: - try: - if hasattr(tool_call_response, "choices") and tool_call_response.choices: - message = tool_call_response.choices[0].message - if hasattr(message, "tool_calls") and message.tool_calls: - original_tool_calls = message.tool_calls - except (AttributeError, IndexError): - pass + tool_calls = self._extract_from_response(tool_call_response) + else: + return [] - if not original_tool_calls: - return False + # 3. Expand concatenated JSON + return self._expand_concatenated_json(tool_calls) + + def _extract_from_response(self, response): + """Extract tool calls from various response formats.""" + original_tool_calls = [] + try: + if hasattr(response, "choices") and response.choices: + message = response.choices[0].message + if hasattr(message, "tool_calls") and message.tool_calls: + original_tool_calls = message.tool_calls + except (AttributeError, IndexError): + pass - # Expand any tool calls that have concatenated JSON in their arguments. - # This is necessary because some models (like Gemini) will serialize - # multiple tool calls in this way. 
+ return original_tool_calls + + def _expand_concatenated_json(self, tool_calls): + """Expand concatenated JSON arguments.""" expanded_tool_calls = [] - for tool_call in original_tool_calls: + for tool_call in tool_calls: args_string = tool_call.function.arguments.strip() # If there are no arguments, or it's not a string that looks like it could @@ -2498,268 +2510,270 @@ async def process_tool_calls(self, tool_call_response): new_tool_call.id = f"{getattr(tool_call, 'id', 'call')}-{i}" expanded_tool_calls.append(new_tool_call) - # Collect all tool calls grouped by server - server_tool_calls = self._gather_server_tool_calls(expanded_tool_calls) - - if server_tool_calls and self.num_tool_calls < self.max_tool_calls: - self._print_tool_call_info(server_tool_calls) - - if await self.io.confirm_ask("Run tools?", group_response="Run MCP Tools"): - tool_responses = await self._execute_tool_calls(server_tool_calls) + return expanded_tool_calls - # Add all tool responses - for tool_response in tool_responses: - ConversationService.get_manager(self).add_message( - message_dict=tool_response, - tag=MessageTag.CUR, - hash_key=(tool_response["tool_call_id"], str(time.monotonic_ns())), - promotion=ConversationService.get_manager(self).DEFAULT_TAG_PROMOTION_VALUE, - mark_for_demotion=1, - ) + def _group_tools_by_executor(self, tool_calls): + """ + Group tools by their server instance. + Returns: dict with server instances as keys and lists of tool calls as values. + Uses servers from self.mcp_manager (including LocalServer for local tools). + """ + groups = {} - return True - elif self.num_tool_calls >= self.max_tool_calls: - self.io.tool_warning(f"Only {self.max_tool_calls} tool calls allowed, stopping.") + for tool_call in tool_calls: + # Find which server in mcp_manager handles this tool + server = self._find_mcp_server_for_tool(tool_call) + if server: + if server not in groups: + groups[server] = [] - return False + _, unprefixed_tool_call = responses.unprefix_tool_call(tool_call) + groups[server].append(unprefixed_tool_call) - def _print_tool_call_info(self, server_tool_calls): - """Print information about an MCP tool call.""" - # self.io.tool_output("Preparing to run MCP tools", bold=False) + return groups - for server, tool_calls in server_tool_calls.items(): - for tool_call in tool_calls: - try: - if ToolRegistry.get_tool(tool_call.function.name.lower()): - ToolRegistry.get_tool(tool_call.function.name.lower()).format_output( - coder=self, mcp_server=server, tool_response=tool_call - ) - else: - print_tool_response(coder=self, mcp_server=server, tool_response=tool_call) - except Exception: - self.io.tool_output(f"Tool Output Error: {tool_call.function.name.lower()}") - self.io.tool_error(traceback.format_exc()) - pass - - def _gather_server_tool_calls(self, tool_calls): - """Collect all tool calls grouped by server. 
- Args: - tool_calls: List of tool calls from the LLM response - - Returns: - dict: Dictionary mapping servers to their respective tool calls - """ + def _find_mcp_server_for_tool(self, tool_call): + """Find which MCP server handles this tool.""" if not self.mcp_tools or len(self.mcp_tools) == 0: return None - server_tool_calls = {} - tool_id_set = set() + # Unprefix the tool name to get the server name and unprefixed tool name + server_name_from_prefix, unprefixed_tool_name = responses.unprefix_tool_name( + nested.getter(tool_call, "function.name") + ) - for tool_call in tool_calls: - # LLM APIs sometimes return duplicates and that's annoying part 3 - if tool_call.get("id") in tool_id_set: - continue + # Check if this tool_call matches any MCP tool + for server_name, server_tools in self.mcp_tools: + for tool in server_tools: + tool_name_from_schema = nested.getter(tool, "function.name") + if ( + tool_name_from_schema + and tool_name_from_schema.lower() == unprefixed_tool_name.lower() + ): + # Find the McpServer instance that will be used for communication + for server in self.mcp_manager: + if server.name == server_name and ( + not server_name_from_prefix or server.name == server_name_from_prefix + ): + return server - tool_id_set.add(tool_call.get("id")) + return None - # Check if this tool_call matches any MCP tool - for server_name, server_tools in self.mcp_tools: - for tool in server_tools: - tool_name_from_schema = tool.get("function", {}).get("name") - if ( - tool_name_from_schema - and tool_name_from_schema.lower() == tool_call.function.name.lower() - ): - # Find the McpServer instance that will be used for communication - for server in self.mcp_manager: - if server.name == server_name: - if server not in server_tool_calls: - server_tool_calls[server] = [] - server_tool_calls[server].append(tool_call) - break - - return server_tool_calls - - async def _execute_tool_calls(self, tool_calls): - """Process tool calls from the response and execute them if they match MCP tools. 
- Returns a list of tool response messages.""" - tool_responses = [] + async def _execute_tool_groups(self, tool_groups): + """Execute all tool groups.""" + all_responses = {} - # Define the coroutine to execute all tool calls for a single server - async def _exec_server_tools(server, tool_calls_list): + # Execute tools for each server + for server, tool_calls in tool_groups.items(): + # Check if this server is an instance of LocalServer (local tools) if isinstance(server, LocalServer): - if hasattr(self, "_execute_local_tool_calls"): - return await self._execute_local_tool_calls(tool_calls_list) - else: - # This coder doesn't support local tools, return errors for all calls - error_responses = [] - for tool_call in tool_calls_list: - error_responses.append( - { - "role": "tool", - "tool_call_id": tool_call.id, - "content": ( - f"Coder does not support local tool: {tool_call.function.name}" - ), - } - ) - return error_responses - - tool_responses = [] - try: - # Connect to the server once - session = await server.connect() - tool_id_set = set() - - # Execute all tool calls for this server - for tool_call in tool_calls_list: - # LLM APIs sometimes return duplicates and that's annoying part 4 - if tool_call.id in tool_id_set: - continue - - tool_id_set.add(tool_call.id) + # Local tools - use _execute_local_tools + local_responses = await self._execute_local_tools(tool_calls) + all_responses[server] = local_responses + else: + # MCP tools - use _execute_mcp_tools + mcp_responses = await self._execute_mcp_tools(server, tool_calls) + all_responses[server] = mcp_responses - try: - # Arguments can be a stream of JSON objects. - # We need to parse them and run a tool call for each. - args_string = tool_call.function.arguments.strip() - parsed_args_list = [] - if args_string: - json_chunks = utils.split_concatenated_json(args_string) - for chunk in json_chunks: - try: - parsed_args_list.append(json.loads(chunk)) - except json.JSONDecodeError: - self.io.tool_warning( - "Malformed JSON arguments in tool" - f" {tool_call.function.name}: {chunk}" - ) - continue + return all_responses - if not parsed_args_list and not args_string: - parsed_args_list.append({}) # For tool calls with no arguments + async def _execute_local_tools(self, tool_calls): + """Execute local tools via ToolRegistry.""" + # Default implementation returns errors + # AgentCoder will override this + error_responses = [] + for tool_call in tool_calls: + error_responses.append( + { + "role": "tool", + "tool_call_id": tool_call.id, + "content": f"Coder does not support local tool: {tool_call.function.name}", + } + ) + return error_responses - all_results_content = [] - for args in parsed_args_list: - new_tool_call = copy_tool_call(tool_call) - new_tool_call.function.arguments = json.dumps(args) + async def _execute_mcp_tools(self, server, tool_calls): + """Execute MCP tools via LiteLLM.""" + tool_responses = [] + try: + # Connect to the server once + session = await server.connect() + tool_id_set = set() - if not await HookIntegration.call_pre_tool_hooks( - self, new_tool_call.function.name, args - ): - self.io.tool_warning("Tool call skipped by pre-tool call hook") - all_results_content.append("Tool Request Aborted.") - continue + # Execute all tool calls for this server + for tool_call in tool_calls: + # LLM APIs sometimes return duplicates and that's annoying part 4 + if tool_call.id in tool_id_set: + continue - call_result = await experimental_mcp_client.call_openai_tool( - session=session, - openai_tool=new_tool_call, - ) + 
tool_id_set.add(tool_call.id) - content_parts = [] - if call_result.content: - for item in call_result.content: - if hasattr(item, "resource"): # EmbeddedResource - resource = item.resource - if hasattr(resource, "text"): # TextResourceContents - content_parts.append(resource.text) - elif hasattr(resource, "blob"): # BlobResourceContents - try: - decoded_blob = base64.b64decode( - resource.blob - ).decode("utf-8") - content_parts.append(decoded_blob) - except (UnicodeDecodeError, TypeError): - # Handle non-text blobs gracefully - name = getattr(resource, "name", "unnamed") - mime_type = getattr( - resource, "mimeType", "unknown mime type" - ) - content_parts.append( - "[embedded binary resource:" - f" {name} ({mime_type})]" - ) - elif hasattr(item, "text"): # TextContent - content_parts.append(item.text) - - result_text = "".join(content_parts) - - if not await HookIntegration.call_post_tool_hooks( - self, new_tool_call.function.name, args, result_text - ): + try: + # Arguments can be a stream of JSON objects. + # We need to parse them and run a tool call for each. + args_string = tool_call.function.arguments.strip() + parsed_args_list = [] + if args_string: + json_chunks = utils.split_concatenated_json(args_string) + for chunk in json_chunks: + try: + parsed_args_list.append(json.loads(chunk)) + except json.JSONDecodeError: self.io.tool_warning( - "Tool call output skipped by post-tool call hook" + "Malformed JSON arguments in tool" + f" {tool_call.function.name}: {chunk}" ) - all_results_content.append("Tool Response Redacted.") continue - all_results_content.append(result_text) + if not parsed_args_list and not args_string: + parsed_args_list.append({}) # For tool calls with no arguments - tool_responses.append( - { - "role": "tool", - "tool_call_id": tool_call.id, - "content": "\n\n".join(all_results_content), - } - ) + all_results_content = [] + for args in parsed_args_list: + new_tool_call = copy_tool_call(tool_call) + new_tool_call.function.arguments = json.dumps(args) - except Exception as e: - tool_error = f"Error executing tool call {tool_call.function.name}: \n{e}" - self.io.tool_warning( - f"Executing {tool_call.function.name} on {server.name} failed: \n " - f" Error: {e}\n" - ) - tool_responses.append( - {"role": "tool", "tool_call_id": tool_call.id, "content": tool_error} + if not await HookIntegration.call_pre_tool_hooks( + self, new_tool_call.function.name, args + ): + self.io.tool_warning("Tool call skipped by pre-tool call hook") + all_results_content.append("Tool Request Aborted.") + continue + + call_result = await experimental_mcp_client.call_openai_tool( + session=session, + openai_tool=new_tool_call, ) - except httpx.RemoteProtocolError as e: - connection_error = f"Server {server.name} disconnected unexpectedly: {e}" - self.io.tool_warning(connection_error) - for tool_call in tool_calls_list: + + content_parts = [] + if call_result.content: + for item in call_result.content: + if hasattr(item, "resource"): # EmbeddedResource + resource = item.resource + if hasattr(resource, "text"): # TextResourceContents + content_parts.append(resource.text) + elif hasattr(resource, "blob"): # BlobResourceContents + try: + decoded_blob = base64.b64decode(resource.blob).decode( + "utf-8" + ) + content_parts.append(decoded_blob) + except (UnicodeDecodeError, TypeError): + # Handle non-text blobs gracefully + name = getattr(resource, "name", "unnamed") + mime_type = getattr( + resource, "mimeType", "unknown mime type" + ) + content_parts.append( + f"[embedded binary resource: {name} 
({mime_type})]" + ) + elif hasattr(item, "text"): # TextContent + content_parts.append(item.text) + + result_text = "".join(content_parts) + + if not await HookIntegration.call_post_tool_hooks( + self, new_tool_call.function.name, args, result_text + ): + self.io.tool_warning("Tool call output skipped by post-tool call hook") + all_results_content.append("Tool Response Redacted.") + continue + + all_results_content.append(result_text) + tool_responses.append( - {"role": "tool", "tool_call_id": tool_call.id, "content": connection_error} + { + "role": "tool", + "tool_call_id": tool_call.id, + "content": "\n\n".join(all_results_content), + } + ) + + except Exception as e: + tool_error = f"Error executing tool call {tool_call.function.name}: \n{e}" + self.io.tool_warning( + f"Executing {tool_call.function.name} on {server.name} failed: \n " + f" Error: {e}\n" ) - except Exception as e: - connection_error = f"Could not connect to server {server.name}\n{e}" - self.io.tool_warning(connection_error) - for tool_call in tool_calls_list: tool_responses.append( - {"role": "tool", "tool_call_id": tool_call.id, "content": connection_error} + {"role": "tool", "tool_call_id": tool_call.id, "content": tool_error} ) + except httpx.RemoteProtocolError as e: + connection_error = f"Server {server.name} disconnected unexpectedly: {e}" + self.io.tool_warning(connection_error) + for tool_call in tool_calls: + tool_responses.append( + {"role": "tool", "tool_call_id": tool_call.id, "content": connection_error} + ) + except Exception as e: + connection_error = f"Could not connect to server {server.name}\n{e}" + self.io.tool_warning(connection_error) + for tool_call in tool_calls: + tool_responses.append( + {"role": "tool", "tool_call_id": tool_call.id, "content": connection_error} + ) - return tool_responses - - # Execute all tool calls concurrently - async def _execute_all_tool_calls(): - tasks = [] - for server, tool_calls_list in tool_calls.items(): - tasks.append(_exec_server_tools(server, tool_calls_list)) - # Wait for all tasks to complete - results = await asyncio.gather(*tasks) - return results - - # Run the async execution and collect results - if tool_calls: - all_results = [] - max_retries = 3 - for i in range(max_retries): - try: - all_results = await _execute_all_tool_calls() - break - except asyncio.exceptions.CancelledError: - if i < max_retries - 1: - await asyncio.sleep(0.1) # Brief pause before retrying - else: - self.io.tool_warning( - "MCP tool execution failed after multiple retries due to cancellation." - ) - all_results = [] + return tool_responses - # Flatten the results from all servers - for server_results in all_results: - tool_responses.extend(server_results) + async def process_tool_calls(self, tool_call_response): + """Simplified main entry point.""" + # Check if max tool calls exceeded + if self.num_tool_calls >= self.max_tool_calls: + self.io.tool_warning(f"Only {self.max_tool_calls} tool calls allowed, stopping.") + return False - return tool_responses + # 1. Extract and prepare tool calls + prepared_calls = self._extract_and_prepare_tool_calls(tool_call_response) + if not prepared_calls: + return False + + # 2. Group by executor + tool_groups = self._group_tools_by_executor(prepared_calls) + + # 3. Print tool call information + if tool_groups: + self._print_tool_call_info(server_tool_calls=tool_groups) + + # 4. Ask for user confirmation + if not await self.io.confirm_ask("Run tools?", group_response="Run MCP Tools"): + return False + + # 5. 
Execute tools + tool_responses_by_server = await self._execute_tool_groups(tool_groups) + + # 6. Add responses to conversation (re-prefixing if necessary) + tool_responses = [] + for server, server_responses in tool_responses_by_server.items(): + for tool_response in server_responses: + tool_responses.append(tool_response) + + ConversationService.get_manager(self).add_message( + message_dict=tool_response, + tag=MessageTag.CUR, + hash_key=(tool_response["tool_call_id"], str(time.monotonic_ns())), + promotion=ConversationService.get_manager(self).DEFAULT_TAG_PROMOTION_VALUE, + mark_for_demotion=1, + ) + + return bool(tool_responses) + + def _print_tool_call_info(self, server_tool_calls): + """Print information about an MCP tool call.""" + # self.io.tool_output("Preparing to run MCP tools", bold=False) + + for server, tool_calls in server_tool_calls.items(): + for tool_call in tool_calls: + try: + if ToolRegistry.get_tool(tool_call.function.name.lower()): + ToolRegistry.get_tool(tool_call.function.name.lower()).format_output( + coder=self, mcp_server=server, tool_response=tool_call + ) + else: + print_tool_response(coder=self, mcp_server=server, tool_response=tool_call) + except Exception: + self.io.tool_output(f"Tool Output Error: {tool_call.function.name.lower()}") + self.io.tool_error(traceback.format_exc()) + pass async def initialize_mcp_tools(self): """ @@ -2779,11 +2793,14 @@ def mcp_tools(self, value): raise AttributeError("mcp_tools is read only.") def get_tool_list(self): - """Get a flattened list of all MCP tools.""" + """Get a flattened list of all MCP tools with server prefixes.""" tool_list = [] if self.mcp_tools: - for _, server_tools in self.mcp_tools: - tool_list.extend(server_tools) + for server_name, server_tools in self.mcp_tools: + for tool in server_tools: + # Prefix the tool name with server name + prefixed_tool = responses.prefix_tool_call(tool, server_name) + tool_list.append(prefixed_tool) return tool_list async def reply_completed(self): diff --git a/cecli/helpers/conversation/files.py b/cecli/helpers/conversation/files.py index b6a710000c5..f0cc62ed77c 100644 --- a/cecli/helpers/conversation/files.py +++ b/cecli/helpers/conversation/files.py @@ -353,7 +353,9 @@ def get_all_tracked_files(self) -> set: image_files = set(self._image_files.keys()) return regular_files.union(image_files) - def update_file_context(self, file_path: str, start_line: int, end_line: int) -> None: + def update_file_context( + self, file_path: str, start_line: int, end_line: int, auto_remove=True + ) -> None: """ Update numbered contexts for a file with a new range. @@ -399,13 +401,8 @@ def update_file_context(self, file_path: str, start_line: int, end_line: int) -> # Remove using hash key (file_context, abs_fname) coder = self.get_coder() - if coder: - ConversationService.get_manager(coder).remove_message_by_hash_key( - ("file_context_user", abs_fname) - ) - ConversationService.get_manager(coder).remove_message_by_hash_key( - ("file_context_assistant", abs_fname) - ) + if coder and auto_remove: + self.remove_file_messages(abs_fname) def get_file_context(self, file_path: str) -> str: """ @@ -482,6 +479,25 @@ def remove_file_context(self, file_path: str) -> None: ("file_context_assistant", abs_fname) ) + def remove_file_messages(self, file_path: str) -> None: + """ + Remove all file messages for a file path. 
+ + Args: + file_path: Absolute file path + """ + abs_fname = os.path.abspath(file_path) + + # Remove using hash key (file_context, abs_fname) + coder = self.get_coder() + if coder: + ConversationService.get_manager(coder).remove_message_by_hash_key( + ("file_context_user", abs_fname) + ) + ConversationService.get_manager(coder).remove_message_by_hash_key( + ("file_context_assistant", abs_fname) + ) + def clear_all_numbered_contexts(self) -> None: """Clear all numbered contexts for all files.""" self._numbered_contexts.clear() diff --git a/cecli/helpers/conversation/integration.py b/cecli/helpers/conversation/integration.py index 40dac2a64aa..7f602645d79 100644 --- a/cecli/helpers/conversation/integration.py +++ b/cecli/helpers/conversation/integration.py @@ -682,7 +682,7 @@ def add_chat_files_messages(self) -> Dict[str, Any]: return result - def add_file_context_messages(self) -> None: + def add_file_context_messages(self, promote_messages=True) -> None: """ Create and insert FILE_CONTEXTS messages based on cached contexts. """ @@ -721,8 +721,12 @@ def add_file_context_messages(self) -> None: tag=MessageTag.FILE_CONTEXTS, hash_key=("file_context_user", file_path), force=True, - promotion=ConversationService.get_manager(coder).DEFAULT_TAG_PROMOTION_VALUE, - mark_for_demotion=1, + promotion=( + ConversationService.get_manager(coder).DEFAULT_TAG_PROMOTION_VALUE + if promote_messages + else None + ), + mark_for_demotion=1 if promote_messages else None, ) ConversationService.get_manager(coder).add_message( @@ -730,8 +734,12 @@ def add_file_context_messages(self) -> None: tag=MessageTag.FILE_CONTEXTS, hash_key=("file_context_assistant", file_path), force=True, - promotion=ConversationService.get_manager(coder).DEFAULT_TAG_PROMOTION_VALUE, - mark_for_demotion=1, + promotion=( + ConversationService.get_manager(coder).DEFAULT_TAG_PROMOTION_VALUE + if promote_messages + else None + ), + mark_for_demotion=1 if promote_messages else None, ) def reset(self) -> None: diff --git a/cecli/helpers/nested.py b/cecli/helpers/nested.py index 624a001df4b..c77be72b717 100644 --- a/cecli/helpers/nested.py +++ b/cecli/helpers/nested.py @@ -54,7 +54,7 @@ def arg_resolver(obj: Union[List[Any], Dict[str, Any], Any], key: str, default: def getter( data: Union[List[Any], Dict[str, Any], Any], path: Union[str, List[str]], default: Any = None ) -> Any: - """Safely access nested dicts and lists using normalized dot-notation.""" + """Safely access nested dicts, lists, and objects using normalized dot-notation.""" if data is None: return default diff --git a/cecli/helpers/requests.py b/cecli/helpers/requests.py index 63a6463dd6f..f2ed199c937 100644 --- a/cecli/helpers/requests.py +++ b/cecli/helpers/requests.py @@ -125,10 +125,34 @@ def flush_user_messages(): return result +def add_continue_for_no_prefill(model, messages): + """Add a 'Continue' user message for models that don't support assistant prefill. 
+ + Args: + model: The model object with info dictionary + messages: List of message dictionaries + + Returns: + List of messages with 'Continue' message added if model doesn't support assistant prefill + and the last message is not already a user message + """ + # Check if model doesn't support assistant prefill + # If not, inject a dummy user message with content "Continue" + # but only if the last message is not already a user message + if not model.info.get("supports_assistant_prefill", False): + # Only add "Continue" if the last message is not a user message + if not messages or messages[-1].get("role") != "user": + # Add a user message with content "Continue" to the messages list + messages.append({"role": "user", "content": "Continue"}) + + return messages + + def model_request_parser(model, messages): messages = thought_signature(model, messages) messages = remove_empty_tool_calls(messages) messages = concatenate_user_messages(messages) messages = ensure_alternating_roles(messages) messages = add_reasoning_content(messages) + messages = add_continue_for_no_prefill(model, messages) return messages diff --git a/cecli/helpers/responses.py b/cecli/helpers/responses.py index efeb69dac51..6f4762e2a88 100644 --- a/cecli/helpers/responses.py +++ b/cecli/helpers/responses.py @@ -130,3 +130,116 @@ def extract_tools_from_content_xml(content: str) -> Optional[List[ChatCompletion return extracted_calls if extracted_calls else None except Exception: return None + + +def prefix_tool_name(server_name: str, tool_name: str) -> str: + """ + Prefix a tool name with the server name. + + Args: + server_name: Name of the MCP server + tool_name: Original tool name + + Returns: + Prefixed tool name in format "{server_name}--{tool_name}" + """ + return f"{server_name}--{tool_name}" + + +def unprefix_tool_name(prefixed_name: str) -> tuple[str, str]: + """ + Unprefix a tool name that may have a server prefix. + + Args: + prefixed_name: Tool name that may be prefixed with "{server_name}--{tool_name}" + + Returns: + Tuple of (server_name, tool_name) where server_name may be empty string + if no prefix is found + """ + # Split on the first double dash + if "--" in prefixed_name: + # Find the first double dash + first_dash_index = prefixed_name.find("--") + server_name = prefixed_name[:first_dash_index] + tool_name = prefixed_name[first_dash_index + 2 :] # +2 to skip both dashes + return server_name, tool_name + return "", prefixed_name + + +def prefix_tool_call(tool_call, server_name: str): + """ + Prefix the function name in a tool call. 
+ + Args: + tool_call: Tool call (dict or ChatCompletionMessageToolCall) with 'function' key/attribute + server_name: Name of the MCP server + + Returns: + New tool call with prefixed function name (same type as input) + """ + # Handle ChatCompletionMessageToolCall objects + if hasattr(tool_call, "function") and hasattr(tool_call.function, "name"): + # Create a copy of the tool call object + result = ChatCompletionMessageToolCall( + id=tool_call.id, + type=tool_call.type, + function=Function( + name=prefix_tool_name(server_name, tool_call.function.name), + arguments=tool_call.function.arguments, + ), + ) + return result + + # Handle dictionaries + if not isinstance(tool_call, dict): + return tool_call + + # Copy the dict and its nested function dict to avoid modifying the original + result = tool_call.copy() + if "function" in result and isinstance(result["function"], dict): + result["function"] = result["function"].copy() + if "name" in result["function"]: + result["function"]["name"] = prefix_tool_name(server_name, result["function"]["name"]) + + return result + + +def unprefix_tool_call(tool_call): + """ + Unprefix the function name in a tool call. + + Args: + tool_call: Tool call (dict or ChatCompletionMessageToolCall) with 'function' key/attribute + + Returns: + Tuple of (server_name, unprefixed_tool_call) where server_name may be empty string + if no prefix is found (same type as input) + """ + # Handle ChatCompletionMessageToolCall objects + if hasattr(tool_call, "function") and hasattr(tool_call.function, "name"): + server_name, unprefixed_name = unprefix_tool_name(tool_call.function.name) + + # Create a copy of the tool call object with unprefixed name + result = ChatCompletionMessageToolCall( + id=tool_call.id, + type=tool_call.type, + function=Function(name=unprefixed_name, arguments=tool_call.function.arguments), + ) + return server_name, result + + # Handle dictionaries + if not isinstance(tool_call, dict): + return "", tool_call + + # Copy the dict and its nested function dict to avoid modifying the original + result = tool_call.copy() + server_name = "" + + if "function" in result and isinstance(result["function"], dict): + result["function"] = result["function"].copy() + if "name" in result["function"]: + server_name, unprefixed_name = unprefix_tool_name(result["function"]["name"]) + result["function"]["name"] = unprefixed_name + + return server_name, result diff --git a/cecli/models.py b/cecli/models.py index e2b8bd5db81..3cf7bb29627 100644 --- a/cecli/models.py +++ b/cecli/models.py @@ -1000,6 +1000,10 @@ async def send_completion( else: temperature = float(self.use_temperature) kwargs["temperature"] = temperature + else: + if override_kwargs and "temperature" in override_kwargs: + override_kwargs.pop("temperature", None) + effective_tools = tools if effective_tools is None and functions: effective_tools = [dict(type="function", function=f) for f in functions] diff --git a/cecli/prompts/agent.yml b/cecli/prompts/agent.yml index f51717fa77f..1e3d7a7d924 100644 --- a/cecli/prompts/agent.yml +++ b/cecli/prompts/agent.yml @@ -24,7 +24,7 @@ main_system: | ### 1. FILE FORMAT - Files are provided in "hashline" format. Each line starts with a content hash followed by `::`. + Files are provided in "hashline" format. Each line starts with a case-sensitive content hash followed by `::`. 
**Example File Format :** il9n::#!/usr/bin/env python3 diff --git a/cecli/prompts/hashline.yml b/cecli/prompts/hashline.yml index 9a812cc6b04..0acc0eb3f52 100644 --- a/cecli/prompts/hashline.yml +++ b/cecli/prompts/hashline.yml @@ -6,7 +6,7 @@ main_system: | Act as an expert software developer. Plan carefully, explain your logic briefly, and execute via LOCATE/CONTENTS blocks. ### 1. FILE FORMAT - Files are provided in "hashline" format. Each line starts with a content hash followed by `::`. + Files are provided in "hashline" format. Each line starts with a case-sensitive content hash followed by `::`. **Example File Format :** il9n::#!/usr/bin/env python3 diff --git a/cecli/tools/delete_text.py b/cecli/tools/delete_text.py index a57e5df8ae4..f5d0b2206f3 100644 --- a/cecli/tools/delete_text.py +++ b/cecli/tools/delete_text.py @@ -11,6 +11,7 @@ class Tool(BaseTool): NORM_NAME = "deletetext" + TRACK_INVOCATIONS = False SCHEMA = { "type": "function", "function": { diff --git a/cecli/tools/insert_text.py b/cecli/tools/insert_text.py index 535428cf5f0..e73c33155ff 100644 --- a/cecli/tools/insert_text.py +++ b/cecli/tools/insert_text.py @@ -14,6 +14,7 @@ class Tool(BaseTool): NORM_NAME = "inserttext" + TRACK_INVOCATIONS = False SCHEMA = { "type": "function", "function": { diff --git a/cecli/tools/replace_text.py b/cecli/tools/replace_text.py index c4ac90871d5..3d9ba450f55 100644 --- a/cecli/tools/replace_text.py +++ b/cecli/tools/replace_text.py @@ -19,6 +19,7 @@ class Tool(BaseTool): NORM_NAME = "replacetext" + TRACK_INVOCATIONS = False SCHEMA = { "type": "function", "function": { diff --git a/cecli/tools/show_context.py b/cecli/tools/show_context.py index 7826fb952e6..204cd94edc7 100644 --- a/cecli/tools/show_context.py +++ b/cecli/tools/show_context.py @@ -82,7 +82,9 @@ def execute(cls, coder, show, **kwargs): Accepts an array of show operations to perform. Uses utility functions for path resolution and error handling. """ - tool_name = "showcontext" + tool_name = "ShowContext" + already_up_to_date = False + try: # 1. Validate show parameter if not isinstance(show, list): @@ -239,14 +241,35 @@ def execute(cls, coder, show, **kwargs): # Note: start_line_idx and end_line_idx are 0-based, convert to 1-based for hashline start_line = start_line_idx + 1 # Convert to 1-based end_line = end_line_idx + 1 # Convert to 1-based + + original_context_content = ConversationService.get_files(coder).get_file_context( + abs_path + ) ConversationService.get_files(coder).update_file_context( - abs_path, start_line, end_line + abs_path, start_line, end_line, auto_remove=False + ) + new_context_content = ConversationService.get_files(coder).get_file_context( + abs_path ) + + if original_context_content and original_context_content == new_context_content: + already_up_to_date = True + else: + ConversationService.get_files(coder).remove_file_messages(abs_path) ConversationService.get_chunks(coder).add_file_context_messages() + # Log success and return the formatted context directly coder.edit_allowed = True - coder.io.tool_output(f"Successfully retrieved context for {len(show)} file(s)") - return f"Successfully retrieved most recent context for {len(show)} file(s)" + + if already_up_to_date: + coder.io.tool_output("File contents already up to date") + return ( + "File contents already up to date. Please proceed with your task. " + "Do not call ShowContext again until you edit the file." 
+ ) + else: + coder.io.tool_output(f"Successfully retrieved context for {len(show)} file(s)") + return f"Successfully retrieved most recent contents for {len(show)} file(s)" except ToolError as e: # Handle expected errors raised by utility functions or validation diff --git a/cecli/tools/utils/base_tool.py b/cecli/tools/utils/base_tool.py index 676fd337094..e801d31373d 100644 --- a/cecli/tools/utils/base_tool.py +++ b/cecli/tools/utils/base_tool.py @@ -12,6 +12,10 @@ class BaseTool(ABC): NORM_NAME = None SCHEMA = None + # Invocation tracking for detecting repeated tool calls + _invocations = {} # Dict to store last 3 invocations per tool + TRACK_INVOCATIONS = True # Default to True, subclasses can override + @classmethod @abstractmethod def execute(cls, coder, **params): @@ -54,6 +58,39 @@ def process_response(cls, coder, params): ) return handle_tool_error(coder, tool_name, ValueError(error_msg)) + # Check for repeated invocations if TRACK_INVOCATIONS is enabled + if cls.TRACK_INVOCATIONS: + tool_name = None + if cls.SCHEMA and "function" in cls.SCHEMA: + tool_name = cls.SCHEMA["function"].get("name", "Unknown Tool") + else: + tool_name = cls.__name__ + + # Initialize invocation tracking for this tool if not exists + if tool_name not in cls._invocations: + cls._invocations[tool_name] = [] + + # Check if current parameters match any of the last 3 invocations + current_params_tuple = tuple( + sorted(params.items()) + ) # Convert to sorted tuple for comparison + + for i, (prev_params_tuple, _) in enumerate(cls._invocations[tool_name]): + if prev_params_tuple == current_params_tuple: + error_msg = ( + f"Tool '{tool_name}' has been called with identical parameters recently. " + "This request is denied to prevent repeated operations." + ) + return handle_tool_error(coder, tool_name, ValueError(error_msg)) + + # Add current invocation to history (keeping only last 3) + cls._invocations[tool_name].append((current_params_tuple, params)) + if len(cls._invocations[tool_name]) > 3: + cls._invocations[tool_name] = cls._invocations[tool_name][-3:] + else: + # When TRACK_INVOCATIONS is False, clear all invocation history + # This indicates the job is making progress, so reset tracking for all tools + cls._invocations.clear() try: return cls.execute(coder, **params) except Exception as e: diff --git a/cecli/tools/utils/helpers.py b/cecli/tools/utils/helpers.py index 187a7e5e728..45e123d91f5 100644 --- a/cecli/tools/utils/helpers.py +++ b/cecli/tools/utils/helpers.py @@ -304,7 +304,7 @@ def handle_tool_error(coder, tool_name, e, add_traceback=True): error_message += f"\n{traceback.format_exc()}" coder.io.tool_error(error_message) # Return only the core error message to the LLM for brevity - return f"Error: {str(e)}" + return f"Error in {tool_name}: {str(e)}" def format_tool_result( diff --git a/cecli/tools/utils/registry.py b/cecli/tools/utils/registry.py index 24f852ddaa4..27ad9e4c15e 100644 --- a/cecli/tools/utils/registry.py +++ b/cecli/tools/utils/registry.py @@ -52,9 +52,9 @@ def build_registry(cls, agent_config: Optional[Dict] = None) -> Dict[str, Type]: agent_config = {} # Load tools from tool_paths if specified - tool_paths = agent_config.get("tool_paths", []) + tools_paths = agent_config.get("tools_paths", agent_config.get("tool_paths", [])) - for tool_path in tool_paths: + for tool_path in tools_paths: path = Path(tool_path) if path.is_dir(): # Find all Python files in the directory diff --git a/tests/basic/test_coder.py b/tests/basic/test_coder.py index 45fa8f1c59f..f780382ff3a 100644 --- 
a/tests/basic/test_coder.py +++ b/tests/basic/test_coder.py @@ -1527,22 +1527,25 @@ async def test_process_tool_calls_with_tools(self): manager._server_tools[mock_server.name] = [{"function": {"name": "test_tool"}}] coder = await Coder.create(self.GPT35, "diff", io=io, mcp_manager=manager) - # Mock _execute_tool_calls to return tool responses - tool_responses = [ - { - "role": "tool", - "tool_call_id": "test_id", - "content": "Tool execution result", - } - ] - coder._execute_tool_calls = AsyncMock(return_value=tool_responses) + # Mock _execute_tool_groups to return tool responses + # Note: _execute_tool_groups now returns a dict keyed by server + tool_responses = { + mock_server: [ + { + "role": "tool", + "tool_call_id": "test_id", + "content": "Tool execution result", + } + ] + } + coder._execute_tool_groups = AsyncMock(return_value=tool_responses) # Test process_tool_calls result = await coder.process_tool_calls(response) assert result - # Verify that _execute_tool_calls was called - coder._execute_tool_calls.assert_called_once() + # Verify that _execute_tool_groups was called + coder._execute_tool_groups.assert_called_once() # Verify that the tool response message was added assert len(coder.cur_messages) == 1 @@ -1669,17 +1672,21 @@ async def test_execute_tool_calls(self, mock_call_tool): mock_result.content = [mock_content_item] mock_call_tool.return_value = mock_result - # Test _execute_tool_calls directly - result = await coder._execute_tool_calls(server_tool_calls) + # Test _execute_tool_groups directly + result = await coder._execute_tool_groups(server_tool_calls) # Verify that server.connect was called mock_server.connect.assert_called_once() # Verify that the correct tool responses were returned + # _execute_tool_groups now returns a dict keyed by server assert len(result) == 1 - assert result[0]["role"] == "tool" - assert result[0]["tool_call_id"] == "test_id" - assert result[0]["content"] == "Tool execution result" + assert mock_server in result + server_responses = result[mock_server] + assert len(server_responses) == 1 + assert server_responses[0]["role"] == "tool" + assert server_responses[0]["tool_call_id"] == "test_id" + assert server_responses[0]["content"] == "Tool execution result" async def test_auto_commit_with_none_content_message(self): """ @@ -1764,20 +1771,23 @@ async def test_execute_tool_calls_multiple_content(self, mock_call_openai_tool): mock_call_result.content = [mock_content1, mock_content2] mock_call_openai_tool.return_value = mock_call_result - # Test _execute_tool_calls directly - result = await coder._execute_tool_calls(server_tool_calls) - + # Test _execute_tool_groups directly + result = await coder._execute_tool_groups(server_tool_calls) # Verify that call_openai_tool was called mock_call_openai_tool.assert_called_once() # Verify that the correct tool responses were returned + # _execute_tool_groups now returns a dict keyed by server assert len(result) == 1 - assert result[0]["role"] == "tool" - assert result[0]["tool_call_id"] == "test_id" + assert mock_server in result + server_responses = result[mock_server] + assert len(server_responses) == 1 + assert server_responses[0]["role"] == "tool" + assert server_responses[0]["tool_call_id"] == "test_id" # This will fail with the current code, which is the point of the test. # The current code returns a hardcoded string. # A fixed version should concatenate the text from all content blocks. - assert result[0]["content"] == "First part. Second part." + assert server_responses[0]["content"] == "First part. 
Second part." @patch( "cecli.coders.base_coder.experimental_mcp_client.call_openai_tool", @@ -1835,19 +1845,23 @@ async def test_execute_tool_calls_blob_content(self, mock_call_openai_tool): ] mock_call_openai_tool.return_value = mock_call_result - # Test _execute_tool_calls directly - result = await coder._execute_tool_calls(server_tool_calls) + # Test _execute_tool_groups directly + result = await coder._execute_tool_groups(server_tool_calls) # Verify that call_openai_tool was called mock_call_openai_tool.assert_called_once() # Verify that the correct tool responses were returned + # _execute_tool_groups now returns a dict keyed by server assert len(result) == 1 - assert result[0]["role"] == "tool" - assert result[0]["tool_call_id"] == "test_id" + assert mock_server in result + server_responses = result[mock_server] + assert len(server_responses) == 1 + assert server_responses[0]["role"] == "tool" + assert server_responses[0]["tool_call_id"] == "test_id" expected_content = ( "Plain text. Hello from blob! [embedded binary resource: binary.dat" " (application/octet-stream)]" ) - assert result[0]["content"] == expected_content + assert server_responses[0]["content"] == expected_content diff --git a/tests/tools/test_insert_block.py b/tests/tools/test_insert_block.py index cd6940f4916..99ec1600cc0 100644 --- a/tests/tools/test_insert_block.py +++ b/tests/tools/test_insert_block.py @@ -110,7 +110,7 @@ def test_mutually_exclusive_parameters_raise(coder_with_file): start_line="invalid_hashline", ) - assert result.startswith("Error:") + assert result.startswith("Error in InsertText:") assert "Hashline insertion failed" in result assert file_path.read_text().startswith("first line") coder.io.tool_error.assert_called() diff --git a/tests/tools/test_show_context.py b/tests/tools/test_show_context.py index 1dc0d194f18..dfedb021adb 100644 --- a/tests/tools/test_show_context.py +++ b/tests/tools/test_show_context.py @@ -63,7 +63,7 @@ def test_pattern_with_zero_line_number_is_allowed(coder_with_file): ) # show_numbered_context now returns a static success message - assert "Successfully retrieved most recent context" in result + assert "Successfully retrieved most recent contents for 1 file(s)" in result coder.io.tool_error.assert_not_called() @@ -83,7 +83,7 @@ def test_empty_pattern_uses_line_number(coder_with_file): ) # show_numbered_context now returns a static success message - assert "Successfully retrieved most recent context" in result + assert "Successfully retrieved most recent contents for 1 file(s)" in result coder.io.tool_error.assert_not_called()