From 7768614a6bbd4adb4e7153e31fcf8c31235d99ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Chindri=C8=99=20Mihai=20Alexandru?= <12643176+chindris-mihai-alexandru@users.noreply.github.com> Date: Wed, 26 Nov 2025 23:19:51 +0200 Subject: [PATCH 1/9] Replace Serper with Exa.ai for semantic web search Exa.ai provides AI-native neural search which offers significant advantages for research agents: - Semantic understanding: Finds relevant results based on meaning, not just keyword matching - Query optimization: Built-in autoprompt improves query quality - Direct content retrieval: Can fetch full page text in a single call - Better for complex queries: Neural embeddings excel at nuanced research questions This change simplifies the codebase by removing the dual search provider system and standardizing on Exa.ai. --- .env.example | 15 ++- inference/tool_search.py | 231 ++++++++++++++++++++++----------------- 2 files changed, 143 insertions(+), 103 deletions(-) diff --git a/.env.example b/.env.example index 8558e9c..f7bb08c 100644 --- a/.env.example +++ b/.env.example @@ -46,9 +46,14 @@ MAX_WORKERS=30 # API Keys and External Services # ============================================================================= -# Serper API for web search and Google Scholar -# Get your key from: https://serper.dev/ -SERPER_KEY_ID=your_key +# Exa.ai API for semantic web search +# Get your key from: https://exa.ai/ +# Exa provides AI-native neural search with: +# - Semantic understanding (not just keyword matching) +# - Built-in query optimization +# - Direct content retrieval +# - Better results for complex research queries +EXA_API_KEY=your_key # Jina API for web page reading # Get your key from: https://jina.ai/ @@ -57,8 +62,8 @@ JINA_API_KEYS=your_key # Summary model API (OpenAI-compatible) for page summarization # Get your key from: https://platform.openai.com/ API_KEY=your_key -API_BASE=your_api_base -SUMMARY_MODEL_NAME=your_summary_model_name +API_BASE=https://api.openai.com/v1 +SUMMARY_MODEL_NAME=gpt-4o-mini # Dashscope API for file parsing (PDF, Office, etc.) # Get your key from: https://dashscope.aliyun.com/ diff --git a/inference/tool_search.py b/inference/tool_search.py index 1a3f7b5..48869a4 100644 --- a/inference/tool_search.py +++ b/inference/tool_search.py @@ -1,131 +1,166 @@ +""" +Exa.ai Search Tool for DeepResearch +AI-native semantic search with neural embeddings for superior research results. + +Exa.ai advantages: +- Neural/semantic search (understands meaning, not just keywords) +- Can retrieve full page contents directly +- Better for research and complex queries +- Built-in query optimization (autoprompt) +- Supports date filtering and domain restrictions +""" + import json -from concurrent.futures import ThreadPoolExecutor -from typing import List, Union +import os +from typing import Any, Dict, List, Optional, Union import requests from qwen_agent.tools.base import BaseTool, register_tool -import asyncio -from typing import Dict, List, Optional, Union -import uuid -import http.client -import json - -import os - -SERPER_KEY=os.environ.get('SERPER_KEY_ID') +EXA_API_KEY = os.environ.get('EXA_API_KEY') +EXA_BASE_URL = "https://api.exa.ai" @register_tool("search", allow_overwrite=True) class Search(BaseTool): name = "search" - description = "Performs batched web searches: supply an array 'query'; the tool retrieves the top 10 results for each query in one call." + description = "Performs semantic web searches using Exa.ai: supply an array 'query'; retrieves top results with AI-powered understanding." 
parameters = { "type": "object", "properties": { "query": { "type": "array", - "items": { - "type": "string" - }, - "description": "Array of query strings. Include multiple complementary search queries in a single call." + "items": {"type": "string"}, + "description": "Array of query strings. Exa understands natural language queries well." }, + "num_results": { + "type": "integer", + "description": "Number of results per query (default: 10, max: 100)", + "default": 10 + }, + "include_contents": { + "type": "boolean", + "description": "Whether to include page text content", + "default": False + } }, "required": ["query"], } def __init__(self, cfg: Optional[dict] = None): super().__init__(cfg) - def google_search_with_serp(self, query: str): - def contains_chinese_basic(text: str) -> bool: - return any('\u4E00' <= char <= '\u9FFF' for char in text) - conn = http.client.HTTPSConnection("google.serper.dev") - if contains_chinese_basic(query): - payload = json.dumps({ - "q": query, - "location": "China", - "gl": "cn", - "hl": "zh-cn" - }) - - else: - payload = json.dumps({ - "q": query, - "location": "United States", - "gl": "us", - "hl": "en" - }) + self.api_key = EXA_API_KEY + if not self.api_key: + raise ValueError("EXA_API_KEY environment variable not set. Get your key from https://exa.ai/") + + def exa_search(self, query: str, num_results: int = 10, include_contents: bool = False) -> str: + """ + Perform a search using Exa.ai API. + + Exa supports multiple search types: + - "auto": Intelligently combines neural and other methods (default) + - "neural": AI-powered semantic search + - "deep": Comprehensive search with query expansion + """ headers = { - 'X-API-KEY': SERPER_KEY, - 'Content-Type': 'application/json' - } + "Content-Type": "application/json", + "x-api-key": self.api_key + } + + payload: Dict[str, Any] = { + "query": query, + "numResults": num_results, + "type": "auto", + "useAutoprompt": True, + } + if include_contents: + payload["contents"] = { + "text": {"maxCharacters": 2000} + } - for i in range(5): + response = None + for attempt in range(3): try: - conn.request("POST", "/search", payload, headers) - res = conn.getresponse() + response = requests.post( + f"{EXA_BASE_URL}/search", + headers=headers, + json=payload, + timeout=30 + ) + response.raise_for_status() break - except Exception as e: - print(e) - if i == 4: - return f"Google search Timeout, return None, Please try again later." + except requests.exceptions.RequestException as e: + if attempt == 2: + return f"Exa search failed after 3 attempts: {str(e)}" continue - - data = res.read() - results = json.loads(data.decode("utf-8")) - - try: - if "organic" not in results: - raise Exception(f"No results found for query: '{query}'. Use a less specific query.") - - web_snippets = list() - idx = 0 - if "organic" in results: - for page in results["organic"]: - idx += 1 - date_published = "" - if "date" in page: - date_published = "\nDate published: " + page["date"] - - source = "" - if "source" in page: - source = "\nSource: " + page["source"] - - snippet = "" - if "snippet" in page: - snippet = "\n" + page["snippet"] - - redacted_version = f"{idx}. 
[{page['title']}]({page['link']}){date_published}{source}\n{snippet}" - redacted_version = redacted_version.replace("Your browser can't play this video.", "") - web_snippets.append(redacted_version) - - content = f"A Google search for '{query}' found {len(web_snippets)} results:\n\n## Web Results\n" + "\n\n".join(web_snippets) - return content - except: - return f"No results found for '{query}'. Try with a more general query." - - - - def search_with_serp(self, query: str): - result = self.google_search_with_serp(query) - return result + + if response is None: + return "Exa search failed: no response received" + + results = response.json() + + if "results" not in results or not results["results"]: + return f"No results found for '{query}'. Try a different query." + + web_snippets = [] + for idx, result in enumerate(results["results"], 1): + title = result.get("title", "No title") + url = result.get("url", "") + published_date = result.get("publishedDate", "") + + snippet_parts = [f"{idx}. [{title}]({url})"] + + if published_date: + snippet_parts.append(f"Date: {published_date[:10]}") + + if include_contents and "text" in result: + text = result["text"][:500] + snippet_parts.append(f"\n{text}...") + elif "snippet" in result: + snippet_parts.append(f"\n{result['snippet']}") + + web_snippets.append("\n".join(snippet_parts)) + + search_type = results.get("resolvedSearchType", "neural") + content = f"Exa {search_type} search for '{query}' found {len(web_snippets)} results:\n\n## Web Results\n\n" + "\n\n".join(web_snippets) + return content - def call(self, params: Union[str, dict], **kwargs) -> str: - try: - query = params["query"] - except: - return "[Search] Invalid request format: Input must be a JSON object containing 'query' field" + def call(self, params: Union[str, dict], **kwargs: Any) -> str: + params_dict: Dict[str, Any] + if isinstance(params, str): + try: + params_dict = json.loads(params) + except json.JSONDecodeError: + return "[Search] Invalid JSON input" + else: + params_dict = dict(params) + + query = params_dict.get("query") + if not query: + return "[Search] Invalid request: 'query' field is required" + + raw_num = params_dict.get("num_results", 10) + num_results = int(raw_num) if raw_num is not None else 10 + include_contents = bool(params_dict.get("include_contents", False)) if isinstance(query, str): - # 单个查询 - response = self.search_with_serp(query) - else: - # 多个查询 - assert isinstance(query, List) + return self.exa_search(query, num_results, include_contents) + + if isinstance(query, list): responses = [] for q in query: - responses.append(self.search_with_serp(q)) - response = "\n=======\n".join(responses) - - return response + responses.append(self.exa_search(q, num_results, include_contents)) + return "\n=======\n".join(responses) + + return "[Search] Invalid query format: must be string or array of strings" + +if __name__ == "__main__": + from dotenv import load_dotenv + + env_path = os.path.join(os.path.dirname(__file__), "..", ".env") + load_dotenv(env_path) + + searcher = Search() + result = searcher.call({"query": ["What is retrieval augmented generation?"]}) + print(result) From 63f5738df30c4c170a541dd6a7963cfac538bb49 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Chindri=C8=99=20Mihai=20Alexandru?= <12643176+chindris-mihai-alexandru@users.noreply.github.com> Date: Wed, 26 Nov 2025 23:28:33 +0200 Subject: [PATCH 2/9] Add category filtering and AI highlights support to Exa search - Add category parameter to filter results (research paper, news, github, etc.) 
- Add AI-generated highlights for better content extraction - Include author information in search results - Document all available Exa categories in docstrings --- inference/tool_search.py | 70 ++++++++++++++++++++++++++++++++++------ 1 file changed, 60 insertions(+), 10 deletions(-) diff --git a/inference/tool_search.py b/inference/tool_search.py index 48869a4..52737ea 100644 --- a/inference/tool_search.py +++ b/inference/tool_search.py @@ -8,6 +8,8 @@ - Better for research and complex queries - Built-in query optimization (autoprompt) - Supports date filtering and domain restrictions +- Category filtering (research papers, news, company info, etc.) +- AI-generated highlights for quick comprehension """ import json @@ -19,11 +21,17 @@ EXA_API_KEY = os.environ.get('EXA_API_KEY') EXA_BASE_URL = "https://api.exa.ai" +# Valid Exa categories for filtering results +VALID_CATEGORIES = [ + "company", "research paper", "news", "pdf", + "github", "tweet", "personal site", "linkedin profile" +] + @register_tool("search", allow_overwrite=True) class Search(BaseTool): name = "search" - description = "Performs semantic web searches using Exa.ai: supply an array 'query'; retrieves top results with AI-powered understanding." + description = "Performs semantic web searches using Exa.ai: supply an array 'query'; retrieves top results with AI-powered understanding. Supports category filtering for research papers, news, etc." parameters = { "type": "object", "properties": { @@ -39,8 +47,13 @@ class Search(BaseTool): }, "include_contents": { "type": "boolean", - "description": "Whether to include page text content", + "description": "Whether to include page text content and highlights", "default": False + }, + "category": { + "type": "string", + "description": "Filter by category: 'research paper', 'news', 'company', 'pdf', 'github', 'tweet', 'personal site', 'linkedin profile'", + "enum": ["company", "research paper", "news", "pdf", "github", "tweet", "personal site", "linkedin profile"] } }, "required": ["query"], @@ -52,7 +65,13 @@ def __init__(self, cfg: Optional[dict] = None): if not self.api_key: raise ValueError("EXA_API_KEY environment variable not set. Get your key from https://exa.ai/") - def exa_search(self, query: str, num_results: int = 10, include_contents: bool = False) -> str: + def exa_search( + self, + query: str, + num_results: int = 10, + include_contents: bool = False, + category: Optional[str] = None + ) -> str: """ Perform a search using Exa.ai API. 
@@ -60,6 +79,16 @@ def exa_search(self, query: str, num_results: int = 10, include_contents: bool = - "auto": Intelligently combines neural and other methods (default) - "neural": AI-powered semantic search - "deep": Comprehensive search with query expansion + + Categories available: + - "research paper": Academic papers and publications + - "news": News articles + - "company": Company websites and info + - "pdf": PDF documents + - "github": GitHub repositories + - "tweet": Twitter/X posts + - "personal site": Personal websites/blogs + - "linkedin profile": LinkedIn profiles """ headers = { "Content-Type": "application/json", @@ -73,9 +102,14 @@ def exa_search(self, query: str, num_results: int = 10, include_contents: bool = "useAutoprompt": True, } + # Add category filter if specified + if category and category in VALID_CATEGORIES: + payload["category"] = category + if include_contents: payload["contents"] = { - "text": {"maxCharacters": 2000} + "text": {"maxCharacters": 2000}, + "highlights": True } response = None @@ -107,22 +141,33 @@ def exa_search(self, query: str, num_results: int = 10, include_contents: bool = title = result.get("title", "No title") url = result.get("url", "") published_date = result.get("publishedDate", "") + author = result.get("author", "") snippet_parts = [f"{idx}. [{title}]({url})"] + if author: + snippet_parts.append(f"Author: {author}") if published_date: snippet_parts.append(f"Date: {published_date[:10]}") - if include_contents and "text" in result: - text = result["text"][:500] - snippet_parts.append(f"\n{text}...") + # Prefer highlights (AI-generated key points), then text, then snippet + if include_contents: + highlights = result.get("highlights", []) + if highlights: + snippet_parts.append("\nKey points:") + for h in highlights[:3]: + snippet_parts.append(f" • {h}") + elif "text" in result: + text = result["text"][:500] + snippet_parts.append(f"\n{text}...") elif "snippet" in result: snippet_parts.append(f"\n{result['snippet']}") web_snippets.append("\n".join(snippet_parts)) search_type = results.get("resolvedSearchType", "neural") - content = f"Exa {search_type} search for '{query}' found {len(web_snippets)} results:\n\n## Web Results\n\n" + "\n\n".join(web_snippets) + category_info = f" (category: {category})" if category else "" + content = f"Exa {search_type} search{category_info} for '{query}' found {len(web_snippets)} results:\n\n## Web Results\n\n" + "\n\n".join(web_snippets) return content def call(self, params: Union[str, dict], **kwargs: Any) -> str: @@ -142,14 +187,19 @@ def call(self, params: Union[str, dict], **kwargs: Any) -> str: raw_num = params_dict.get("num_results", 10) num_results = int(raw_num) if raw_num is not None else 10 include_contents = bool(params_dict.get("include_contents", False)) + category = params_dict.get("category") + + # Validate category if provided + if category and category not in VALID_CATEGORIES: + category = None if isinstance(query, str): - return self.exa_search(query, num_results, include_contents) + return self.exa_search(query, num_results, include_contents, category) if isinstance(query, list): responses = [] for q in query: - responses.append(self.exa_search(q, num_results, include_contents)) + responses.append(self.exa_search(q, num_results, include_contents, category)) return "\n=======\n".join(responses) return "[Search] Invalid query format: must be string or array of strings" From 8fda6ecdb55407dfc958bca7140b00bc0b4141f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Chindri=C8=99=20Mihai=20Alexandru?= 
<12643176+chindris-mihai-alexandru@users.noreply.github.com> Date: Wed, 26 Nov 2025 23:34:50 +0200 Subject: [PATCH 3/9] Polish Exa search: remove unused import, add rate limit and auth error handling --- inference/tool_search.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/inference/tool_search.py b/inference/tool_search.py index 52737ea..cb3b364 100644 --- a/inference/tool_search.py +++ b/inference/tool_search.py @@ -14,7 +14,7 @@ import json import os -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, Optional, Union import requests from qwen_agent.tools.base import BaseTool, register_tool @@ -123,6 +123,13 @@ def exa_search( ) response.raise_for_status() break + except requests.exceptions.HTTPError as e: + if response is not None and response.status_code == 429: + return f"Exa search rate limited. Please wait and try again." + if response is not None and response.status_code == 401: + return f"Exa API key invalid. Check your EXA_API_KEY environment variable." + if attempt == 2: + return f"Exa search failed after 3 attempts: {str(e)}" except requests.exceptions.RequestException as e: if attempt == 2: return f"Exa search failed after 3 attempts: {str(e)}" From 2d8d3887e9a98e1911c69bbff5808d520f3b2e83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Chindri=C8=99=20Mihai=20Alexandru?= <12643176+chindris-mihai-alexandru@users.noreply.github.com> Date: Thu, 27 Nov 2025 00:14:38 +0200 Subject: [PATCH 4/9] Add MLX support for Apple Silicon inference - Add run_mlx_react.py: React agent runner using MLX-lm server - Add run_mlx_infer.sh: Shell script to start MLX server and run inference - Add test_mlx_connection.py: Test script to verify MLX server connectivity - Update .env.example with MLX configuration options Enables running DeepResearch on Apple Silicon Macs (M1/M2/M3/M4) using the MLX framework instead of CUDA/vLLM. Uses the 4-bit quantized model (abalogh/Tongyi-DeepResearch-30B-A3B-4bit, ~17GB) which fits in 32GB RAM. Tested on M1 Max with 32GB RAM - model loads and inference works correctly. --- .env.example | 22 +- inference/run_mlx_infer.sh | 134 +++++++++ inference/run_mlx_react.py | 479 +++++++++++++++++++++++++++++++ inference/test_mlx_connection.py | 85 ++++++ 4 files changed, 719 insertions(+), 1 deletion(-) create mode 100755 inference/run_mlx_infer.sh create mode 100644 inference/run_mlx_react.py create mode 100644 inference/test_mlx_connection.py diff --git a/.env.example b/.env.example index f7bb08c..43e38c9 100644 --- a/.env.example +++ b/.env.example @@ -100,4 +100,24 @@ IDP_KEY_SECRET=your_idp_key_secret # These are typically set by distributed training frameworks # WORLD_SIZE=1 -# RANK=0 \ No newline at end of file +# RANK=0 + +# ============================================================================= +# MLX Configuration (Apple Silicon Only) +# ============================================================================= +# For running on Apple Silicon Macs (M1/M2/M3/M4) using MLX framework +# instead of CUDA/vLLM. Uses mlx-lm for efficient local inference. 
+# +# Requirements: +# pip install mlx-lm +# +# Recommended models: +# - abalogh/Tongyi-DeepResearch-30B-A3B-4bit (17GB, fits 32GB RAM) +# - Original BF16 model requires 62GB+ +# +# Usage: +# bash inference/run_mlx_infer.sh +# +# MLX_MODEL=abalogh/Tongyi-DeepResearch-30B-A3B-4bit +# MLX_HOST=127.0.0.1 +# MLX_PORT=8080 \ No newline at end of file diff --git a/inference/run_mlx_infer.sh b/inference/run_mlx_infer.sh new file mode 100755 index 0000000..3ad9823 --- /dev/null +++ b/inference/run_mlx_infer.sh @@ -0,0 +1,134 @@ +#!/bin/bash +# MLX Inference Script for Apple Silicon (M1/M2/M3/M4) +# This script runs DeepResearch using Apple's MLX framework instead of vLLM/CUDA + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$SCRIPT_DIR/.." +ENV_FILE="$PROJECT_ROOT/.env" +VENV_PATH="$PROJECT_ROOT/.venv" + +# Activate virtual environment if it exists +if [ -d "$VENV_PATH" ]; then + echo "Activating virtual environment..." + source "$VENV_PATH/bin/activate" +else + echo "Warning: No .venv found at $VENV_PATH" + echo "Make sure mlx-lm is installed: pip install mlx-lm" +fi + +# Load environment variables +if [ ! -f "$ENV_FILE" ]; then + echo "Error: .env file not found at $ENV_FILE" + echo "Please copy .env.example to .env and configure your settings" + exit 1 +fi + +echo "Loading environment variables..." +set -a +source "$ENV_FILE" +set +a + +# MLX-specific configuration +MLX_MODEL="${MLX_MODEL:-abalogh/Tongyi-DeepResearch-30B-A3B-4bit}" +MLX_PORT="${MLX_PORT:-8080}" +MLX_HOST="${MLX_HOST:-127.0.0.1}" + +# Default inference parameters for MLX +TEMPERATURE="${TEMPERATURE:-0.85}" +MAX_TOKENS="${MAX_TOKENS:-10000}" + +echo "============================================" +echo "DeepResearch MLX Inference (Apple Silicon)" +echo "============================================" +echo "Model: $MLX_MODEL" +echo "Server: http://$MLX_HOST:$MLX_PORT" +echo "Temperature: $TEMPERATURE" +echo "============================================" + +# Check if mlx-lm is installed +if ! command -v mlx_lm.server &> /dev/null; then + echo "Error: mlx-lm not installed. Install with: pip install mlx-lm" + exit 1 +fi + +###################################### +### 1. Start MLX Server ### +###################################### + +echo "Starting MLX server..." +echo "Note: First run will download the model (~17GB for 4-bit version)" + +# Kill any existing MLX server on the port +lsof -ti:$MLX_PORT | xargs kill -9 2>/dev/null || true + +# Start MLX server in background +mlx_lm.server \ + --model "$MLX_MODEL" \ + --host "$MLX_HOST" \ + --port "$MLX_PORT" \ + --temp "$TEMPERATURE" \ + --max-tokens "$MAX_TOKENS" \ + --trust-remote-code \ + --log-level INFO \ + --use-default-chat-template & + +MLX_PID=$! +echo "MLX server started with PID: $MLX_PID" + +# Trap to cleanup on exit +cleanup() { + echo "Shutting down MLX server..." + kill $MLX_PID 2>/dev/null || true + exit 0 +} +trap cleanup SIGINT SIGTERM EXIT + +###################################### +### 2. Wait for server to be ready ### +###################################### + +echo "Waiting for MLX server to be ready..." +timeout=600 # 10 minutes (model download may take time) +start_time=$(date +%s) + +while true; do + if curl -s -f "http://$MLX_HOST:$MLX_PORT/v1/models" > /dev/null 2>&1; then + echo "MLX server is ready!" + break + fi + + current_time=$(date +%s) + elapsed=$((current_time - start_time)) + + if [ $elapsed -gt $timeout ]; then + echo "Error: MLX server startup timeout after ${timeout} seconds" + exit 1 + fi + + echo -n "." 
+ sleep 5 +done + +###################################### +### 3. Run Inference ### +###################################### + +echo "" +echo "==== Starting inference ====" + +cd "$SCRIPT_DIR" + +# Use MLX-specific react agent script +python -u run_mlx_react.py \ + --dataset "${DATASET:-$PROJECT_ROOT/eval_data/sample_questions.jsonl}" \ + --output "${OUTPUT_PATH:-./outputs}" \ + --max_workers "${MAX_WORKERS:-1}" \ + --model "$MLX_MODEL" \ + --mlx_port "$MLX_PORT" \ + --temperature "$TEMPERATURE" \ + --presence_penalty "${PRESENCE_PENALTY:-1.1}" \ + --roll_out_count "${ROLLOUT_COUNT:-1}" + +echo "Inference complete!" diff --git a/inference/run_mlx_react.py b/inference/run_mlx_react.py new file mode 100644 index 0000000..1f982ba --- /dev/null +++ b/inference/run_mlx_react.py @@ -0,0 +1,479 @@ +""" +MLX React Agent Runner for Apple Silicon + +This script runs DeepResearch using Apple's MLX framework instead of vLLM/CUDA. +It connects to the MLX-lm server which provides an OpenAI-compatible API. + +Usage: + python run_mlx_react.py --dataset eval_data/test.jsonl --output ./outputs +""" + +import argparse +import json +import os +import time +import threading +import random +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime +from tqdm import tqdm + +import json5 +from openai import OpenAI, APIError, APIConnectionError, APITimeoutError + +from prompt import SYSTEM_PROMPT + +# Tool registry - import tools with fallbacks for compatibility +TOOL_MAP = {} + +try: + from tool_search import Search + TOOL_MAP['search'] = Search() +except ImportError as e: + print(f"Warning: Could not import Search tool: {e}") + +try: + from tool_visit import Visit + TOOL_MAP['visit'] = Visit() +except ImportError as e: + print(f"Warning: Could not import Visit tool: {e}") + +try: + from tool_scholar import Scholar + TOOL_MAP['google_scholar'] = Scholar() +except ImportError as e: + print(f"Warning: Could not import Scholar tool: {e}") + +try: + from tool_file import FileParser + TOOL_MAP['parse_file'] = FileParser() +except ImportError as e: + print(f"Warning: Could not import FileParser tool: {e}") + +try: + from tool_python import PythonInterpreter + TOOL_MAP['PythonInterpreter'] = PythonInterpreter() +except (ImportError, Exception) as e: + print(f"Warning: Could not import PythonInterpreter tool: {e}") + +print(f"Loaded tools: {list(TOOL_MAP.keys())}") + +MAX_LLM_CALL_PER_RUN = int(os.getenv('MAX_LLM_CALL_PER_RUN', 100)) + + +def today_date(): + return datetime.now().strftime("%Y-%m-%d") + + +class MLXReactAgent: + """ + React agent that uses MLX-lm server for inference on Apple Silicon. + + The MLX server provides an OpenAI-compatible API at /v1/chat/completions, + making it a drop-in replacement for vLLM/sglang servers. 
+ """ + + def __init__(self, model: str, mlx_host: str = "127.0.0.1", mlx_port: int = 8080, + temperature: float = 0.85, top_p: float = 0.95, + presence_penalty: float = 1.1, max_tokens: int = 10000): + self.model = model + self.mlx_host = mlx_host + self.mlx_port = mlx_port + self.temperature = temperature + self.top_p = top_p + self.presence_penalty = presence_penalty + self.max_tokens = max_tokens + + # Create OpenAI client pointing to MLX server + self.client = OpenAI( + api_key="mlx-local", # MLX server doesn't require auth + base_url=f"http://{mlx_host}:{mlx_port}/v1", + timeout=600.0, + ) + + # Verify connection + self._verify_connection() + + def _verify_connection(self): + """Verify that the MLX server is running and accessible.""" + try: + models = self.client.models.list() + available = [m.id for m in models.data] + print(f"MLX server connected. Available models: {available}") + except Exception as e: + raise ConnectionError(f"Cannot connect to MLX server at {self.mlx_host}:{self.mlx_port}: {e}") + + def call_server(self, messages: list, max_tries: int = 10) -> str: + """Call the MLX server with exponential backoff retry.""" + base_sleep = 1 + + for attempt in range(max_tries): + try: + print(f"--- MLX call attempt {attempt + 1}/{max_tries} ---") + + response = self.client.chat.completions.create( + model=self.model, + messages=messages, + stop=["\n", ""], + temperature=self.temperature, + top_p=self.top_p, + max_tokens=self.max_tokens, + # Note: MLX-lm may not support all parameters + # presence_penalty=self.presence_penalty, + ) + + content = response.choices[0].message.content + + if content and content.strip(): + print("--- MLX call successful ---") + return content.strip() + + print(f"Warning: Attempt {attempt + 1} received empty response") + + except (APIError, APIConnectionError, APITimeoutError) as e: + print(f"API error on attempt {attempt + 1}: {e}") + except Exception as e: + print(f"Unexpected error on attempt {attempt + 1}: {e}") + + if attempt < max_tries - 1: + sleep_time = min(base_sleep * (2 ** attempt) + random.uniform(0, 1), 30) + print(f"Retrying in {sleep_time:.2f}s...") + time.sleep(sleep_time) + + return "MLX server error - all retries exhausted" + + def estimate_tokens(self, messages: list) -> int: + """ + Rough token estimation without loading a full tokenizer. + MLX models typically use ~4 chars per token for English text. + """ + total_chars = sum(len(m.get("content", "")) for m in messages) + return total_chars // 4 + + def custom_call_tool(self, tool_name: str, tool_args: dict) -> str: + """Execute a tool and return the result.""" + if tool_name not in TOOL_MAP: + return f"Error: Tool {tool_name} not found" + + tool_args["params"] = tool_args + + if "python" in tool_name.lower(): + return TOOL_MAP['PythonInterpreter'].call(tool_args) + + if tool_name == "parse_file": + import asyncio + params = {"files": tool_args["files"]} + result = asyncio.run(TOOL_MAP[tool_name].call(params, file_root_path="./eval_data/file_corpus")) + return str(result) if not isinstance(result, str) else result + + return TOOL_MAP[tool_name].call(tool_args) + + def run(self, data: dict) -> dict: + """ + Run the react agent loop for a single question. 
+ + Args: + data: Dict with 'item' containing 'question' and optionally 'answer' + + Returns: + Dict with question, answer, messages, prediction, and termination status + """ + # Extract question + item = data['item'] + question = item.get('question', '') + if not question: + try: + raw_msg = item['messages'][1]["content"] + question = raw_msg.split("User:")[1].strip() if "User:" in raw_msg else raw_msg + except Exception as e: + print(f"Failed to extract question: {e}") + return {"question": "", "error": "Could not extract question"} + + answer = item.get('answer', '') + start_time = time.time() + + # Build initial messages + system_prompt = SYSTEM_PROMPT + str(today_date()) + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": question} + ] + + num_calls_remaining = MAX_LLM_CALL_PER_RUN + round_num = 0 + max_context_tokens = 110 * 1024 # ~110K tokens max + timeout_minutes = 150 # 2.5 hours + + while num_calls_remaining > 0: + # Check timeout + elapsed = time.time() - start_time + if elapsed > timeout_minutes * 60: + return { + "question": question, + "answer": answer, + "messages": messages, + "prediction": "No answer found after timeout", + "termination": "timeout" + } + + round_num += 1 + num_calls_remaining -= 1 + + # Call MLX server + content = self.call_server(messages) + print(f"Round {round_num}: {content[:200]}..." if len(content) > 200 else f"Round {round_num}: {content}") + + # Clean up response + if '<tool_response>' in content: + content = content[:content.find('<tool_response>')] + + messages.append({"role": "assistant", "content": content.strip()}) + + # Check for tool calls + if '<tool_call>' in content and '</tool_call>' in content: + tool_call_str = content.split('<tool_call>')[1].split('</tool_call>')[0] + + try: + if "python" in tool_call_str.lower(): + try: + code = content.split('<tool_call>')[1].split('</tool_call>')[0] + code = code.split('<code>')[1].split('</code>')[0].strip() + result = TOOL_MAP['PythonInterpreter'].call(code) + except: + result = "[Python Interpreter Error]: Formatting error." + else: + tool_call = json5.loads(tool_call_str) + tool_name = tool_call.get('name', '') + tool_args = tool_call.get('arguments', {}) + result = self.custom_call_tool(tool_name, tool_args) + except Exception as e: + result = f'Error: Tool call is not valid JSON. Must contain "name" and "arguments" fields. Error: {e}' + + result = f"<tool_response>\n{result}\n</tool_response>" + messages.append({"role": "user", "content": result}) + + # Check for final answer + if '<answer>' in content and '</answer>' in content: + prediction = content.split('<answer>')[1].split('</answer>')[0] + return { + "question": question, + "answer": answer, + "messages": messages, + "prediction": prediction, + "termination": "answer" + } + + # Check token limit + token_count = self.estimate_tokens(messages) + print(f"Round {round_num}, estimated tokens: {token_count}") + + if token_count > max_context_tokens: + print(f"Token limit exceeded: {token_count} > {max_context_tokens}") + + # Force final answer + messages[-1]['content'] = ( + "You have reached the maximum context length. 
Stop making tool calls and " + "provide your best answer based on all information above in this format:\n" + "<think>your final thinking</think>\n<answer>your answer</answer>" + ) + + content = self.call_server(messages) + messages.append({"role": "assistant", "content": content.strip()}) + + if '<answer>' in content and '</answer>' in content: + prediction = content.split('<answer>')[1].split('</answer>')[0] + termination = "token_limit_answer" + else: + prediction = content + termination = "token_limit_format_error" + + return { + "question": question, + "answer": answer, + "messages": messages, + "prediction": prediction, + "termination": termination + } + + if num_calls_remaining <= 0: + messages[-1]['content'] = "Maximum LLM calls reached." + + # No answer found + if '<answer>' in messages[-1]['content']: + prediction = messages[-1]['content'].split('<answer>')[1].split('</answer>')[0] + termination = "answer" + else: + prediction = "No answer found." + termination = "calls_exhausted" + + return { + "question": question, + "answer": answer, + "messages": messages, + "prediction": prediction, + "termination": termination + } + + +def main(): + parser = argparse.ArgumentParser(description="Run DeepResearch with MLX on Apple Silicon") + parser.add_argument("--model", type=str, default="abalogh/Tongyi-DeepResearch-30B-A3B-4bit", + help="Model name (should match MLX server)") + parser.add_argument("--dataset", type=str, required=True, + help="Path to input dataset (JSON or JSONL)") + parser.add_argument("--output", type=str, default="./outputs", + help="Output directory") + parser.add_argument("--mlx_host", type=str, default="127.0.0.1", + help="MLX server host") + parser.add_argument("--mlx_port", type=int, default=8080, + help="MLX server port") + parser.add_argument("--temperature", type=float, default=0.85, + help="Sampling temperature") + parser.add_argument("--top_p", type=float, default=0.95, + help="Top-p sampling") + parser.add_argument("--presence_penalty", type=float, default=1.1, + help="Presence penalty") + parser.add_argument("--max_workers", type=int, default=1, + help="Number of parallel workers (keep at 1 for MLX)") + parser.add_argument("--roll_out_count", type=int, default=1, + help="Number of rollouts per question") + args = parser.parse_args() + + # Setup output directory + model_name = os.path.basename(args.model.rstrip('/')) + model_dir = os.path.join(args.output, f"{model_name}_mlx") + dataset_name = os.path.splitext(os.path.basename(args.dataset))[0] + output_dir = os.path.join(model_dir, dataset_name) + os.makedirs(output_dir, exist_ok=True) + + print("=" * 50) + print("DeepResearch MLX Inference") + print("=" * 50) + print(f"Model: {args.model}") + print(f"Dataset: {args.dataset}") + print(f"Output: {output_dir}") + print(f"MLX Server: http://{args.mlx_host}:{args.mlx_port}") + print(f"Temperature: {args.temperature}") + print(f"Rollouts: {args.roll_out_count}") + print("=" * 50) + + # Load dataset + try: + if args.dataset.endswith(".json"): + with open(args.dataset, "r", encoding="utf-8") as f: + items = json.load(f) + elif args.dataset.endswith(".jsonl"): + with open(args.dataset, "r", encoding="utf-8") as f: + items = [json.loads(line) for line in f] + else: + raise ValueError("Dataset must be .json or .jsonl") + except FileNotFoundError: + print(f"Error: Dataset not found at {args.dataset}") + return + except Exception as e: + print(f"Error loading dataset: {e}") + return + + print(f"Loaded {len(items)} items from dataset") + + # Initialize agent + try: + agent = MLXReactAgent( + model=args.model, + mlx_host=args.mlx_host, + mlx_port=args.mlx_port, 
+ temperature=args.temperature, + top_p=args.top_p, + presence_penalty=args.presence_penalty, + ) + except ConnectionError as e: + print(f"Error: {e}") + print("Make sure the MLX server is running:") + print(f" mlx_lm.server --model {args.model} --port {args.mlx_port}") + return + + # Setup output files per rollout + output_files = { + i: os.path.join(output_dir, f"iter{i}.jsonl") + for i in range(1, args.roll_out_count + 1) + } + + # Load already processed questions + processed_per_rollout = {} + for rollout_idx in range(1, args.roll_out_count + 1): + processed = set() + output_file = output_files[rollout_idx] + if os.path.exists(output_file): + with open(output_file, "r", encoding="utf-8") as f: + for line in f: + try: + data = json.loads(line) + if "question" in data and "error" not in data: + processed.add(data["question"].strip()) + except json.JSONDecodeError: + pass + processed_per_rollout[rollout_idx] = processed + print(f"Rollout {rollout_idx}: {len(processed)} already processed") + + # Build task list + tasks = [] + for rollout_idx in range(1, args.roll_out_count + 1): + processed = processed_per_rollout[rollout_idx] + for item in items: + question = item.get("question", "").strip() + if not question: + try: + user_msg = item["messages"][1]["content"] + question = user_msg.split("User:")[1].strip() if "User:" in user_msg else user_msg + item["question"] = question + except: + continue + + if question and question not in processed: + tasks.append({ + "item": item.copy(), + "rollout_idx": rollout_idx, + }) + + print(f"Tasks to run: {len(tasks)}") + + if not tasks: + print("All tasks already completed!") + return + + # Run tasks + # Note: MLX is single-threaded on GPU, so max_workers=1 is recommended + write_locks = {i: threading.Lock() for i in range(1, args.roll_out_count + 1)} + + for task in tqdm(tasks, desc="Processing"): + rollout_idx = task["rollout_idx"] + output_file = output_files[rollout_idx] + + try: + result = agent.run(task) + result["rollout_idx"] = rollout_idx + + with write_locks[rollout_idx]: + with open(output_file, "a", encoding="utf-8") as f: + f.write(json.dumps(result, ensure_ascii=False) + "\n") + + except Exception as e: + print(f"Error processing task: {e}") + error_result = { + "question": task["item"].get("question", ""), + "answer": task["item"].get("answer", ""), + "rollout_idx": rollout_idx, + "error": str(e), + "messages": [], + "prediction": "[Failed]" + } + with write_locks[rollout_idx]: + with open(output_file, "a", encoding="utf-8") as f: + f.write(json.dumps(error_result, ensure_ascii=False) + "\n") + + print("\nInference complete!") + print(f"Results saved to: {output_dir}") + + +if __name__ == "__main__": + main() diff --git a/inference/test_mlx_connection.py b/inference/test_mlx_connection.py new file mode 100644 index 0000000..3064a7d --- /dev/null +++ b/inference/test_mlx_connection.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python3 +""" +Quick test script to verify MLX server connection and basic inference. +Run this after starting the MLX server to verify everything works. 
+ +Usage: + # Terminal 1: Start MLX server + mlx_lm.server --model abalogh/Tongyi-DeepResearch-30B-A3B-4bit --port 8080 + + # Terminal 2: Run this test + python test_mlx_connection.py +""" + +import sys +from openai import OpenAI + +MLX_HOST = "127.0.0.1" +MLX_PORT = 8080 + + +def test_connection(): + """Test basic connection to MLX server.""" + print(f"Testing connection to MLX server at {MLX_HOST}:{MLX_PORT}...") + + client = OpenAI( + api_key="mlx-local", + base_url=f"http://{MLX_HOST}:{MLX_PORT}/v1", + timeout=60.0, + ) + + # Test 1: List models + print("\n1. Listing available models...") + try: + models = client.models.list() + available = [m.id for m in models.data] + print(f" Available models: {available}") + except Exception as e: + print(f" FAILED: {e}") + return False + + # Test 2: Simple completion + print("\n2. Testing simple completion...") + try: + response = client.chat.completions.create( + model=available[0] if available else "default", + messages=[ + {"role": "user", "content": "What is 2+2? Answer with just the number."} + ], + max_tokens=10, + temperature=0.1, + ) + answer = response.choices[0].message.content + print(f" Response: {answer}") + except Exception as e: + print(f" FAILED: {e}") + return False + + # Test 3: Test with system prompt (like DeepResearch uses) + print("\n3. Testing with system prompt...") + try: + response = client.chat.completions.create( + model=available[0] if available else "default", + messages=[ + {"role": "system", "content": "You are a helpful research assistant. Think step by step."}, + {"role": "user", "content": "What is the capital of Japan?"} + ], + max_tokens=100, + temperature=0.7, + ) + answer = response.choices[0].message.content or "" + print(f" Response: {answer[:200]}..." if len(answer) > 200 else f" Response: {answer}") + except Exception as e: + print(f" FAILED: {e}") + return False + + print("\n" + "=" * 50) + print("All tests passed! MLX server is working correctly.") + print("You can now run: bash inference/run_mlx_infer.sh") + print("=" * 50) + return True + + +if __name__ == "__main__": + success = test_connection() + sys.exit(0 if success else 1) From 7e93ea97de19d8ac6959906cc1f264931e3883cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Chindri=C8=99=20Mihai=20Alexandru?= <12643176+chindris-mihai-alexandru@users.noreply.github.com> Date: Thu, 27 Nov 2025 00:48:08 +0200 Subject: [PATCH 5/9] fix: use native MLX API instead of server for proper chat template handling The MLX OpenAI-compatible server was not applying the chat template correctly, causing tool calls to fail. This change: - Uses native mlx_lm.load() and mlx_lm.generate() Python API directly - Builds chat prompts with proper Qwen format (<|im_start|>...<|im_end|>) - Uses make_sampler() for temperature/top_p settings (new mlx-lm API) - Removes server dependency - model is loaded directly in Python - Adds test_mlx_tool_loop.py for debugging tool call issues - Simplifies run_mlx_infer.sh (no server startup needed) Tested successfully: model now generates proper tags and the agent loop executes search/visit tools correctly. 
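For reference, a minimal sketch of the native generation path this patch switches to, using only calls that appear in run_mlx_react.py (the model ID and sampling values are the defaults configured above; the prompt string is a toy example in the Qwen format that build_prompt() produces):

    from mlx_lm import load, generate
    from mlx_lm.sample_utils import make_sampler

    # Load the quantized model and its tokenizer (downloaded on first use)
    model, tokenizer = load("abalogh/Tongyi-DeepResearch-30B-A3B-4bit")

    # Temperature/top_p go through a sampler object in recent mlx-lm versions
    sampler = make_sampler(temp=0.85, top_p=0.95)

    prompt = "<|im_start|>user\nWhat is MLX?<|im_end|>\n<|im_start|>assistant\n"
    print(generate(model, tokenizer, prompt=prompt, max_tokens=256,
                   sampler=sampler, verbose=False))

The agent builds the full system + conversation prompt the same way before every generate() call.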
--- inference/run_mlx_infer.sh | 95 +++---------- inference/run_mlx_react.py | 243 +++++++++++++++----------------- inference/test_mlx_tool_loop.py | 177 +++++++++++++++++++++++ 3 files changed, 305 insertions(+), 210 deletions(-) create mode 100644 inference/test_mlx_tool_loop.py diff --git a/inference/run_mlx_infer.sh b/inference/run_mlx_infer.sh index 3ad9823..dd45fd1 100755 --- a/inference/run_mlx_infer.sh +++ b/inference/run_mlx_infer.sh @@ -1,21 +1,25 @@ #!/bin/bash # MLX Inference Script for Apple Silicon (M1/M2/M3/M4) # This script runs DeepResearch using Apple's MLX framework instead of vLLM/CUDA +# +# Uses native MLX Python API (no separate server needed) set -e SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" PROJECT_ROOT="$SCRIPT_DIR/.." ENV_FILE="$PROJECT_ROOT/.env" -VENV_PATH="$PROJECT_ROOT/.venv" +VENV_PATH="$PROJECT_ROOT/venv" # Activate virtual environment if it exists if [ -d "$VENV_PATH" ]; then echo "Activating virtual environment..." source "$VENV_PATH/bin/activate" else - echo "Warning: No .venv found at $VENV_PATH" - echo "Make sure mlx-lm is installed: pip install mlx-lm" + echo "Warning: No venv found at $VENV_PATH" + echo "Create one with: python3 -m venv $VENV_PATH" + echo "Then install: pip install mlx-lm python-dotenv requests json5 tqdm qwen-agent" + exit 1 fi # Load environment variables @@ -32,103 +36,40 @@ set +a # MLX-specific configuration MLX_MODEL="${MLX_MODEL:-abalogh/Tongyi-DeepResearch-30B-A3B-4bit}" -MLX_PORT="${MLX_PORT:-8080}" -MLX_HOST="${MLX_HOST:-127.0.0.1}" -# Default inference parameters for MLX +# Default inference parameters TEMPERATURE="${TEMPERATURE:-0.85}" -MAX_TOKENS="${MAX_TOKENS:-10000}" +MAX_TOKENS="${MAX_TOKENS:-8192}" +TOP_P="${TOP_P:-0.95}" echo "============================================" echo "DeepResearch MLX Inference (Apple Silicon)" echo "============================================" echo "Model: $MLX_MODEL" -echo "Server: http://$MLX_HOST:$MLX_PORT" echo "Temperature: $TEMPERATURE" +echo "Top-P: $TOP_P" +echo "Max Tokens: $MAX_TOKENS" echo "============================================" # Check if mlx-lm is installed -if ! command -v mlx_lm.server &> /dev/null; then +python -c "import mlx_lm" 2>/dev/null || { echo "Error: mlx-lm not installed. Install with: pip install mlx-lm" exit 1 -fi - -###################################### -### 1. Start MLX Server ### -###################################### - -echo "Starting MLX server..." -echo "Note: First run will download the model (~17GB for 4-bit version)" - -# Kill any existing MLX server on the port -lsof -ti:$MLX_PORT | xargs kill -9 2>/dev/null || true - -# Start MLX server in background -mlx_lm.server \ - --model "$MLX_MODEL" \ - --host "$MLX_HOST" \ - --port "$MLX_PORT" \ - --temp "$TEMPERATURE" \ - --max-tokens "$MAX_TOKENS" \ - --trust-remote-code \ - --log-level INFO \ - --use-default-chat-template & - -MLX_PID=$! -echo "MLX server started with PID: $MLX_PID" - -# Trap to cleanup on exit -cleanup() { - echo "Shutting down MLX server..." - kill $MLX_PID 2>/dev/null || true - exit 0 } -trap cleanup SIGINT SIGTERM EXIT - -###################################### -### 2. Wait for server to be ready ### -###################################### - -echo "Waiting for MLX server to be ready..." -timeout=600 # 10 minutes (model download may take time) -start_time=$(date +%s) - -while true; do - if curl -s -f "http://$MLX_HOST:$MLX_PORT/v1/models" > /dev/null 2>&1; then - echo "MLX server is ready!" 
- break - fi - - current_time=$(date +%s) - elapsed=$((current_time - start_time)) - - if [ $elapsed -gt $timeout ]; then - echo "Error: MLX server startup timeout after ${timeout} seconds" - exit 1 - fi - - echo -n "." - sleep 5 -done - -###################################### -### 3. Run Inference ### -###################################### -echo "" -echo "==== Starting inference ====" +# Disable tokenizer parallelism warning +export TOKENIZERS_PARALLELISM=false +# Run inference using native MLX API (no server needed) cd "$SCRIPT_DIR" -# Use MLX-specific react agent script python -u run_mlx_react.py \ --dataset "${DATASET:-$PROJECT_ROOT/eval_data/sample_questions.jsonl}" \ --output "${OUTPUT_PATH:-./outputs}" \ - --max_workers "${MAX_WORKERS:-1}" \ --model "$MLX_MODEL" \ - --mlx_port "$MLX_PORT" \ --temperature "$TEMPERATURE" \ - --presence_penalty "${PRESENCE_PENALTY:-1.1}" \ + --top_p "$TOP_P" \ + --max_tokens "$MAX_TOKENS" \ --roll_out_count "${ROLLOUT_COUNT:-1}" echo "Inference complete!" diff --git a/inference/run_mlx_react.py b/inference/run_mlx_react.py index 1f982ba..1ef0ffa 100644 --- a/inference/run_mlx_react.py +++ b/inference/run_mlx_react.py @@ -2,7 +2,7 @@ MLX React Agent Runner for Apple Silicon This script runs DeepResearch using Apple's MLX framework instead of vLLM/CUDA. -It connects to the MLX-lm server which provides an OpenAI-compatible API. +Uses native MLX Python API with proper chat template handling. Usage: python run_mlx_react.py --dataset eval_data/test.jsonl --output ./outputs @@ -13,18 +13,22 @@ import os import time import threading -import random -from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime -from tqdm import tqdm +from typing import Any, Dict, List, Optional + +# Load environment variables +from dotenv import load_dotenv +load_dotenv(os.path.join(os.path.dirname(__file__), "..", ".env")) +from tqdm import tqdm import json5 -from openai import OpenAI, APIError, APIConnectionError, APITimeoutError +from mlx_lm import load, generate +from mlx_lm.sample_utils import make_sampler from prompt import SYSTEM_PROMPT # Tool registry - import tools with fallbacks for compatibility -TOOL_MAP = {} +TOOL_MAP: Dict[str, Any] = {} try: from tool_search import Search @@ -61,104 +65,83 @@ MAX_LLM_CALL_PER_RUN = int(os.getenv('MAX_LLM_CALL_PER_RUN', 100)) -def today_date(): +def today_date() -> str: return datetime.now().strftime("%Y-%m-%d") class MLXReactAgent: """ - React agent that uses MLX-lm server for inference on Apple Silicon. + React agent that uses native MLX Python API for inference on Apple Silicon. - The MLX server provides an OpenAI-compatible API at /v1/chat/completions, - making it a drop-in replacement for vLLM/sglang servers. + Uses the model's chat template directly for proper tool-calling format. 
""" - def __init__(self, model: str, mlx_host: str = "127.0.0.1", mlx_port: int = 8080, - temperature: float = 0.85, top_p: float = 0.95, - presence_penalty: float = 1.1, max_tokens: int = 10000): - self.model = model - self.mlx_host = mlx_host - self.mlx_port = mlx_port + def __init__(self, model_path: str, temperature: float = 0.85, + top_p: float = 0.95, max_tokens: int = 8192): + self.model_path = model_path self.temperature = temperature self.top_p = top_p - self.presence_penalty = presence_penalty self.max_tokens = max_tokens - # Create OpenAI client pointing to MLX server - self.client = OpenAI( - api_key="mlx-local", # MLX server doesn't require auth - base_url=f"http://{mlx_host}:{mlx_port}/v1", - timeout=600.0, - ) - - # Verify connection - self._verify_connection() + print(f"Loading model: {model_path}") + self.model, self.tokenizer = load(model_path) + print("Model loaded successfully") - def _verify_connection(self): - """Verify that the MLX server is running and accessible.""" - try: - models = self.client.models.list() - available = [m.id for m in models.data] - print(f"MLX server connected. Available models: {available}") - except Exception as e: - raise ConnectionError(f"Cannot connect to MLX server at {self.mlx_host}:{self.mlx_port}: {e}") + def build_prompt(self, messages: List[Dict[str, str]]) -> str: + """ + Build prompt using the Qwen chat template format. + Format: <|im_start|>role\ncontent<|im_end|> + """ + prompt_parts = [] + for msg in messages: + role = msg["role"] + content = msg["content"] + prompt_parts.append(f"<|im_start|>{role}\n{content}<|im_end|>") + + # Add assistant start token for generation + prompt_parts.append("<|im_start|>assistant\n") + return "\n".join(prompt_parts) - def call_server(self, messages: list, max_tries: int = 10) -> str: - """Call the MLX server with exponential backoff retry.""" - base_sleep = 1 + def generate_response(self, messages: List[Dict[str, str]], max_tokens: Optional[int] = None) -> str: + """Generate response using native MLX API.""" + prompt = self.build_prompt(messages) + tokens = max_tokens or self.max_tokens - for attempt in range(max_tries): - try: - print(f"--- MLX call attempt {attempt + 1}/{max_tries} ---") - - response = self.client.chat.completions.create( - model=self.model, - messages=messages, - stop=["\n", ""], - temperature=self.temperature, - top_p=self.top_p, - max_tokens=self.max_tokens, - # Note: MLX-lm may not support all parameters - # presence_penalty=self.presence_penalty, - ) - - content = response.choices[0].message.content - - if content and content.strip(): - print("--- MLX call successful ---") - return content.strip() - - print(f"Warning: Attempt {attempt + 1} received empty response") - - except (APIError, APIConnectionError, APITimeoutError) as e: - print(f"API error on attempt {attempt + 1}: {e}") - except Exception as e: - print(f"Unexpected error on attempt {attempt + 1}: {e}") - - if attempt < max_tries - 1: - sleep_time = min(base_sleep * (2 ** attempt) + random.uniform(0, 1), 30) - print(f"Retrying in {sleep_time:.2f}s...") - time.sleep(sleep_time) + # Create sampler with temperature and top_p + sampler = make_sampler(temp=self.temperature, top_p=self.top_p) + + # Generate with sampler + response = generate( + self.model, + self.tokenizer, + prompt=prompt, + max_tokens=tokens, + sampler=sampler, + verbose=False, + ) + + # Clean up response - remove trailing tokens + if "<|im_end|>" in response: + response = response.split("<|im_end|>")[0] + if "" in response: + response = 
response.split("")[0] - return "MLX server error - all retries exhausted" + return response.strip() - def estimate_tokens(self, messages: list) -> int: - """ - Rough token estimation without loading a full tokenizer. - MLX models typically use ~4 chars per token for English text. - """ + def estimate_tokens(self, messages: List[Dict[str, str]]) -> int: + """Rough token estimation.""" total_chars = sum(len(m.get("content", "")) for m in messages) return total_chars // 4 - def custom_call_tool(self, tool_name: str, tool_args: dict) -> str: + def custom_call_tool(self, tool_name: str, tool_args: Dict[str, Any]) -> str: """Execute a tool and return the result.""" if tool_name not in TOOL_MAP: - return f"Error: Tool {tool_name} not found" + return f"Error: Tool {tool_name} not found. Available: {list(TOOL_MAP.keys())}" tool_args["params"] = tool_args if "python" in tool_name.lower(): - return TOOL_MAP['PythonInterpreter'].call(tool_args) + return str(TOOL_MAP['PythonInterpreter'].call(tool_args)) if tool_name == "parse_file": import asyncio @@ -166,9 +149,9 @@ def custom_call_tool(self, tool_name: str, tool_args: dict) -> str: result = asyncio.run(TOOL_MAP[tool_name].call(params, file_root_path="./eval_data/file_corpus")) return str(result) if not isinstance(result, str) else result - return TOOL_MAP[tool_name].call(tool_args) + return str(TOOL_MAP[tool_name].call(tool_args)) - def run(self, data: dict) -> dict: + def run(self, data: Dict[str, Any]) -> Dict[str, Any]: """ Run the react agent loop for a single question. @@ -194,7 +177,7 @@ def run(self, data: dict) -> dict: # Build initial messages system_prompt = SYSTEM_PROMPT + str(today_date()) - messages = [ + messages: List[Dict[str, str]] = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": question} ] @@ -219,38 +202,41 @@ def run(self, data: dict) -> dict: round_num += 1 num_calls_remaining -= 1 - # Call MLX server - content = self.call_server(messages) - print(f"Round {round_num}: {content[:200]}..." if len(content) > 200 else f"Round {round_num}: {content}") + print(f"--- Round {round_num} ---") - # Clean up response - if '' in content: - content = content[:content.find('')] + # Generate response + content = self.generate_response(messages) - messages.append({"role": "assistant", "content": content.strip()}) + preview = content[:200] + "..." if len(content) > 200 else content + print(f"Response: {preview}") + + messages.append({"role": "assistant", "content": content}) # Check for tool calls if '' in content and '' in content: tool_call_str = content.split('')[1].split('')[0] try: - if "python" in tool_call_str.lower(): + if "python" in tool_call_str.lower() and "" in content: try: code = content.split('')[1].split('')[0] code = code.split('')[1].split('')[0].strip() - result = TOOL_MAP['PythonInterpreter'].call(code) - except: + result = str(TOOL_MAP['PythonInterpreter'].call(code)) + except Exception: result = "[Python Interpreter Error]: Formatting error." else: tool_call = json5.loads(tool_call_str) tool_name = tool_call.get('name', '') tool_args = tool_call.get('arguments', {}) + print(f"Tool call: {tool_name} with args: {tool_args}") result = self.custom_call_tool(tool_name, tool_args) except Exception as e: result = f'Error: Tool call is not valid JSON. Must contain "name" and "arguments" fields. Error: {e}' - result = f"\n{result}\n" - messages.append({"role": "user", "content": result}) + print(f"Tool result preview: {result[:200]}..." 
if len(result) > 200 else f"Tool result: {result}") + + tool_response = f"\n{result}\n" + messages.append({"role": "user", "content": tool_response}) # Check for final answer if '' in content and '' in content: @@ -265,20 +251,21 @@ def run(self, data: dict) -> dict: # Check token limit token_count = self.estimate_tokens(messages) - print(f"Round {round_num}, estimated tokens: {token_count}") + print(f"Estimated tokens: {token_count}") if token_count > max_context_tokens: print(f"Token limit exceeded: {token_count} > {max_context_tokens}") # Force final answer - messages[-1]['content'] = ( - "You have reached the maximum context length. Stop making tool calls and " - "provide your best answer based on all information above in this format:\n" - "your final thinking\nyour answer" - ) + messages.append({ + "role": "user", + "content": "You have reached the maximum context length. Stop making tool calls and " + "provide your best answer based on all information above in this format:\n" + "your final thinking\nyour answer" + }) - content = self.call_server(messages) - messages.append({"role": "assistant", "content": content.strip()}) + content = self.generate_response(messages) + messages.append({"role": "assistant", "content": content}) if '' in content and '' in content: prediction = content.split('')[1].split('')[0] @@ -296,11 +283,15 @@ def run(self, data: dict) -> dict: } if num_calls_remaining <= 0: - messages[-1]['content'] = "Maximum LLM calls reached." + messages.append({ + "role": "user", + "content": "Maximum LLM calls reached. Please provide your final answer now." + }) # No answer found - if '' in messages[-1]['content']: - prediction = messages[-1]['content'].split('')[1].split('')[0] + last_content = messages[-1].get('content', '') + if '' in last_content: + prediction = last_content.split('')[1].split('')[0] termination = "answer" else: prediction = "No answer found." 
@@ -318,23 +309,17 @@ def run(self, data: dict) -> dict: def main(): parser = argparse.ArgumentParser(description="Run DeepResearch with MLX on Apple Silicon") parser.add_argument("--model", type=str, default="abalogh/Tongyi-DeepResearch-30B-A3B-4bit", - help="Model name (should match MLX server)") + help="Model path or HuggingFace model ID") parser.add_argument("--dataset", type=str, required=True, help="Path to input dataset (JSON or JSONL)") parser.add_argument("--output", type=str, default="./outputs", help="Output directory") - parser.add_argument("--mlx_host", type=str, default="127.0.0.1", - help="MLX server host") - parser.add_argument("--mlx_port", type=int, default=8080, - help="MLX server port") parser.add_argument("--temperature", type=float, default=0.85, help="Sampling temperature") parser.add_argument("--top_p", type=float, default=0.95, help="Top-p sampling") - parser.add_argument("--presence_penalty", type=float, default=1.1, - help="Presence penalty") - parser.add_argument("--max_workers", type=int, default=1, - help="Number of parallel workers (keep at 1 for MLX)") + parser.add_argument("--max_tokens", type=int, default=8192, + help="Maximum tokens per generation") parser.add_argument("--roll_out_count", type=int, default=1, help="Number of rollouts per question") args = parser.parse_args() @@ -347,12 +332,11 @@ def main(): os.makedirs(output_dir, exist_ok=True) print("=" * 50) - print("DeepResearch MLX Inference") + print("DeepResearch MLX Inference (Native API)") print("=" * 50) print(f"Model: {args.model}") print(f"Dataset: {args.dataset}") print(f"Output: {output_dir}") - print(f"MLX Server: http://{args.mlx_host}:{args.mlx_port}") print(f"Temperature: {args.temperature}") print(f"Rollouts: {args.roll_out_count}") print("=" * 50) @@ -377,20 +361,12 @@ def main(): print(f"Loaded {len(items)} items from dataset") # Initialize agent - try: - agent = MLXReactAgent( - model=args.model, - mlx_host=args.mlx_host, - mlx_port=args.mlx_port, - temperature=args.temperature, - top_p=args.top_p, - presence_penalty=args.presence_penalty, - ) - except ConnectionError as e: - print(f"Error: {e}") - print("Make sure the MLX server is running:") - print(f" mlx_lm.server --model {args.model} --port {args.mlx_port}") - return + agent = MLXReactAgent( + model_path=args.model, + temperature=args.temperature, + top_p=args.top_p, + max_tokens=args.max_tokens, + ) # Setup output files per rollout output_files = { @@ -399,9 +375,9 @@ def main(): } # Load already processed questions - processed_per_rollout = {} + processed_per_rollout: Dict[int, set] = {} for rollout_idx in range(1, args.roll_out_count + 1): - processed = set() + processed: set = set() output_file = output_files[rollout_idx] if os.path.exists(output_file): with open(output_file, "r", encoding="utf-8") as f: @@ -426,7 +402,7 @@ def main(): user_msg = item["messages"][1]["content"] question = user_msg.split("User:")[1].strip() if "User:" in user_msg else user_msg item["question"] = question - except: + except Exception: continue if question and question not in processed: @@ -442,7 +418,6 @@ def main(): return # Run tasks - # Note: MLX is single-threaded on GPU, so max_workers=1 is recommended write_locks = {i: threading.Lock() for i in range(1, args.roll_out_count + 1)} for task in tqdm(tasks, desc="Processing"): @@ -459,6 +434,8 @@ def main(): except Exception as e: print(f"Error processing task: {e}") + import traceback + traceback.print_exc() error_result = { "question": task["item"].get("question", ""), "answer": 
task["item"].get("answer", ""), diff --git a/inference/test_mlx_tool_loop.py b/inference/test_mlx_tool_loop.py new file mode 100644 index 0000000..e430d8e --- /dev/null +++ b/inference/test_mlx_tool_loop.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python3 +""" +Diagnostic test for MLX tool response injection. + +This script tests the complete tool call loop: +1. Send a question to MLX +2. Model generates ... +3. We parse and execute the tool +4. We inject ... +5. Model continues with the tool response + +Usage: + python test_mlx_tool_loop.py +""" + +import os +import sys +import json +from datetime import datetime +from typing import Optional, Tuple, Dict, Any, List + +from dotenv import load_dotenv +load_dotenv(os.path.join(os.path.dirname(__file__), "..", ".env")) + +from openai import OpenAI +import json5 + +# Import tools +sys.path.insert(0, os.path.dirname(__file__)) +from tool_search import Search +from prompt import SYSTEM_PROMPT + +TOOL_MAP: Dict[str, Any] = {"search": Search()} + + +def today_date() -> str: + return datetime.now().strftime("%Y-%m-%d") + + +def parse_tool_call(content: str) -> Tuple[Optional[str], Optional[Dict[str, Any]]]: + """Extract tool name and arguments from model output.""" + if "" not in content or "" not in content: + return None, None + + tool_call_str = content.split("")[1].split("")[0].strip() + + try: + tool_call = json5.loads(tool_call_str) + name = tool_call.get("name") if isinstance(tool_call, dict) else None + args = tool_call.get("arguments", {}) if isinstance(tool_call, dict) else {} + return name, args + except Exception as e: + print(f"Failed to parse tool call JSON: {e}") + print(f"Raw tool call: {tool_call_str}") + return None, None + + +def execute_tool(name: str, args: Dict[str, Any]) -> str: + """Execute a tool and return the result.""" + if name not in TOOL_MAP: + return f"Error: Tool '{name}' not found. Available: {list(TOOL_MAP.keys())}" + + args["params"] = args + return TOOL_MAP[name].call(args) + + +def test_tool_loop(): + """Test the complete tool call loop.""" + print("=" * 60) + print("MLX Tool Response Injection Diagnostic Test") + print("=" * 60) + + # Connect to MLX server + client = OpenAI( + api_key="mlx-local", + base_url="http://127.0.0.1:8080/v1", + timeout=300.0, + ) + + # Verify connection + try: + models = client.models.list() + print(f"Connected to MLX server. Model: {models.data[0].id}") + except Exception as e: + print(f"ERROR: Cannot connect to MLX server: {e}") + print("Make sure the MLX server is running:") + print(" mlx_lm.server --model abalogh/Tongyi-DeepResearch-30B-A3B-4bit --port 8080") + return + + # Build messages + system_prompt = SYSTEM_PROMPT + str(today_date()) + question = "What are the latest developments in quantum computing in 2024?" 
+ + messages: List[Dict[str, str]] = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": question} + ] + + print(f"\nQuestion: {question}") + print("-" * 60) + + max_rounds = 5 + for round_num in range(1, max_rounds + 1): + print(f"\n--- Round {round_num} ---") + print(f"Messages count: {len(messages)}") + print(f"Last message role: {messages[-1]['role']}") + + # Call MLX + response = client.chat.completions.create( + model=models.data[0].id, + messages=messages, # type: ignore + stop=["\n", ""], + temperature=0.85, + top_p=0.95, + max_tokens=8192, + ) + + raw_content = response.choices[0].message.content + finish_reason = response.choices[0].finish_reason + content = raw_content.strip() if raw_content else "" + + print(f"\nFinish reason: {finish_reason}") + + # Clean up if leaked + if "" in content: + content = content[:content.find("")] + + print(f"\nModel output ({len(content)} chars):") + print("-" * 40) + print(content[:1000] + "..." if len(content) > 1000 else content) + print("-" * 40) + + # Add assistant message + messages.append({"role": "assistant", "content": content}) + + # Check for final answer + if "" in content and "" in content: + answer = content.split("")[1].split("")[0] + print(f"\nFINAL ANSWER: {answer}") + print(f"Total rounds: {round_num}") + break + + # Check for tool call + tool_name, tool_args = parse_tool_call(content) + + if tool_name and tool_args is not None: + print(f"\nTool call detected: {tool_name}") + print(f"Arguments: {json.dumps(tool_args, indent=2)}") + + # Execute tool + result = execute_tool(tool_name, tool_args) + print(f"\nTool result ({len(result)} chars):") + print(result[:500] + "..." if len(result) > 500 else result) + + # Inject tool response + tool_response = f"\n{result}\n" + messages.append({"role": "user", "content": tool_response}) + print(f"\nInjected tool_response as user message") + else: + print("\nNo tool call detected in output") + if round_num < max_rounds: + print("Model may be stuck - no tool call and no answer") + + # Print final message history + print("\n" + "=" * 60) + print("FULL MESSAGE HISTORY") + print("=" * 60) + for i, msg in enumerate(messages): + role = msg["role"] + content = msg["content"] + preview = content[:200] + "..." if len(content) > 200 else content + print(f"\n[{i}] {role.upper()}:") + print(preview) + + +if __name__ == "__main__": + test_tool_loop() From 07174a49923fef85704e81c98af48191704103b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Chindri=C8=99=20Mihai=20Alexandru?= <12643176+chindris-mihai-alexandru@users.noreply.github.com> Date: Thu, 27 Nov 2025 00:55:38 +0200 Subject: [PATCH 6/9] feat: improve MLX runner with graceful shutdown, proper token counting, and timeout protection --- inference/run_mlx_react.py | 320 +++++++++++++++++++++++++------------ 1 file changed, 221 insertions(+), 99 deletions(-) diff --git a/inference/run_mlx_react.py b/inference/run_mlx_react.py index 1ef0ffa..a0e17e1 100644 --- a/inference/run_mlx_react.py +++ b/inference/run_mlx_react.py @@ -4,6 +4,9 @@ This script runs DeepResearch using Apple's MLX framework instead of vLLM/CUDA. Uses native MLX Python API with proper chat template handling. 
+Requirements: + pip install mlx-lm python-dotenv requests json5 tqdm qwen-agent + Usage: python run_mlx_react.py --dataset eval_data/test.jsonl --output ./outputs """ @@ -11,12 +14,14 @@ import argparse import json import os +import signal +import sys import time import threading from datetime import datetime from typing import Any, Dict, List, Optional -# Load environment variables +# Load environment variables before other imports from dotenv import load_dotenv load_dotenv(os.path.join(os.path.dirname(__file__), "..", ".env")) @@ -27,7 +32,10 @@ from prompt import SYSTEM_PROMPT -# Tool registry - import tools with fallbacks for compatibility +# Disable tokenizer parallelism warning +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +# Tool registry - import tools with fallbacks TOOL_MAP: Dict[str, Any] = {} try: @@ -64,6 +72,23 @@ MAX_LLM_CALL_PER_RUN = int(os.getenv('MAX_LLM_CALL_PER_RUN', 100)) +# Graceful shutdown flag +shutdown_requested = False + + +def signal_handler(signum, frame): + """Handle interrupt signals gracefully.""" + global shutdown_requested + if shutdown_requested: + print("\nForce quit...") + sys.exit(1) + shutdown_requested = True + print("\nShutdown requested. Finishing current task...") + + +signal.signal(signal.SIGINT, signal_handler) +signal.signal(signal.SIGTERM, signal_handler) + def today_date() -> str: return datetime.now().strftime("%Y-%m-%d") @@ -71,9 +96,9 @@ def today_date() -> str: class MLXReactAgent: """ - React agent that uses native MLX Python API for inference on Apple Silicon. + React agent using native MLX Python API for inference on Apple Silicon. - Uses the model's chat template directly for proper tool-calling format. + Uses the model's built-in chat template for proper formatting. """ def __init__(self, model_path: str, temperature: float = 0.85, @@ -85,32 +110,57 @@ def __init__(self, model_path: str, temperature: float = 0.85, print(f"Loading model: {model_path}") self.model, self.tokenizer = load(model_path) - print("Model loaded successfully") + print(f"Model loaded successfully (memory: {self._get_memory_usage():.1f} GB)") + + def _get_memory_usage(self) -> float: + """Get current GPU memory usage in GB.""" + try: + import mlx.core as mx + # Force memory stats update + mx.metal.get_peak_memory() + return mx.metal.get_active_memory() / (1024**3) + except Exception: + return 0.0 def build_prompt(self, messages: List[Dict[str, str]]) -> str: """ - Build prompt using the Qwen chat template format. - Format: <|im_start|>role\ncontent<|im_end|> + Build prompt using tokenizer's chat template. + Falls back to manual Qwen format if template unavailable. 
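+
+        Manual fallback layout (one ChatML block per message, ending with an
+        open assistant turn):
+            <|im_start|>{role}
+            {content}<|im_end|>
+            ...
+            <|im_start|>assistant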
""" + # Try using tokenizer's built-in chat template + if hasattr(self.tokenizer, 'apply_chat_template'): + try: + prompt = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + return prompt + except Exception as e: + print(f"Warning: apply_chat_template failed, using manual format: {e}") + + # Fallback: Manual Qwen/ChatML format prompt_parts = [] for msg in messages: role = msg["role"] content = msg["content"] prompt_parts.append(f"<|im_start|>{role}\n{content}<|im_end|>") - - # Add assistant start token for generation prompt_parts.append("<|im_start|>assistant\n") return "\n".join(prompt_parts) + def count_tokens(self, messages: List[Dict[str, str]]) -> int: + """Count tokens using the actual tokenizer.""" + prompt = self.build_prompt(messages) + tokens = self.tokenizer.encode(prompt) + return len(tokens) + def generate_response(self, messages: List[Dict[str, str]], max_tokens: Optional[int] = None) -> str: """Generate response using native MLX API.""" prompt = self.build_prompt(messages) tokens = max_tokens or self.max_tokens - # Create sampler with temperature and top_p sampler = make_sampler(temp=self.temperature, top_p=self.top_p) - # Generate with sampler response = generate( self.model, self.tokenizer, @@ -128,28 +178,42 @@ def generate_response(self, messages: List[Dict[str, str]], max_tokens: Optional return response.strip() - def estimate_tokens(self, messages: List[Dict[str, str]]) -> int: - """Rough token estimation.""" - total_chars = sum(len(m.get("content", "")) for m in messages) - return total_chars // 4 - - def custom_call_tool(self, tool_name: str, tool_args: Dict[str, Any]) -> str: - """Execute a tool and return the result.""" + def execute_tool(self, tool_name: str, tool_args: Dict[str, Any], timeout: int = 120) -> str: + """Execute a tool with timeout protection.""" if tool_name not in TOOL_MAP: - return f"Error: Tool {tool_name} not found. Available: {list(TOOL_MAP.keys())}" + return f"Error: Tool '{tool_name}' not found. 
Available: {list(TOOL_MAP.keys())}" + # Prepare args tool_args["params"] = tool_args + result = "" + error = None - if "python" in tool_name.lower(): - return str(TOOL_MAP['PythonInterpreter'].call(tool_args)) + def run_tool(): + nonlocal result, error + try: + if "python" in tool_name.lower(): + result = str(TOOL_MAP['PythonInterpreter'].call(tool_args)) + elif tool_name == "parse_file": + import asyncio + params = {"files": tool_args.get("files", [])} + r = asyncio.run(TOOL_MAP[tool_name].call(params, file_root_path="./eval_data/file_corpus")) + result = str(r) if not isinstance(r, str) else r + else: + result = str(TOOL_MAP[tool_name].call(tool_args)) + except Exception as e: + error = str(e) + + thread = threading.Thread(target=run_tool) + thread.start() + thread.join(timeout=timeout) + + if thread.is_alive(): + return f"Error: Tool '{tool_name}' timed out after {timeout}s" - if tool_name == "parse_file": - import asyncio - params = {"files": tool_args["files"]} - result = asyncio.run(TOOL_MAP[tool_name].call(params, file_root_path="./eval_data/file_corpus")) - return str(result) if not isinstance(result, str) else result + if error: + return f"Error executing tool '{tool_name}': {error}" - return str(TOOL_MAP[tool_name].call(tool_args)) + return result def run(self, data: Dict[str, Any]) -> Dict[str, Any]: """ @@ -161,6 +225,8 @@ def run(self, data: Dict[str, Any]) -> Dict[str, Any]: Returns: Dict with question, answer, messages, prediction, and termination status """ + global shutdown_requested + # Extract question item = data['item'] question = item.get('question', '') @@ -184,10 +250,20 @@ def run(self, data: Dict[str, Any]) -> Dict[str, Any]: num_calls_remaining = MAX_LLM_CALL_PER_RUN round_num = 0 - max_context_tokens = 110 * 1024 # ~110K tokens max - timeout_minutes = 150 # 2.5 hours + max_context_tokens = 100 * 1024 # 100K tokens (conservative for 128K model) + timeout_minutes = 120 # 2 hours while num_calls_remaining > 0: + # Check for shutdown + if shutdown_requested: + return { + "question": question, + "answer": answer, + "messages": messages, + "prediction": "Interrupted by user", + "termination": "interrupted" + } + # Check timeout elapsed = time.time() - start_time if elapsed > timeout_minutes * 60: @@ -202,7 +278,7 @@ def run(self, data: Dict[str, Any]) -> Dict[str, Any]: round_num += 1 num_calls_remaining -= 1 - print(f"--- Round {round_num} ---") + print(f"--- Round {round_num} (calls left: {num_calls_remaining}) ---") # Generate response content = self.generate_response(messages) @@ -217,23 +293,23 @@ def run(self, data: Dict[str, Any]) -> Dict[str, Any]: tool_call_str = content.split('')[1].split('')[0] try: + # Handle Python interpreter specially if "python" in tool_call_str.lower() and "" in content: - try: - code = content.split('')[1].split('')[0] - code = code.split('')[1].split('')[0].strip() - result = str(TOOL_MAP['PythonInterpreter'].call(code)) - except Exception: - result = "[Python Interpreter Error]: Formatting error." 
+ code = content.split('')[1].split('')[0].strip() + result = self.execute_tool('PythonInterpreter', {"code": code}) else: - tool_call = json5.loads(tool_call_str) + tool_call = json5.loads(tool_call_str.strip()) tool_name = tool_call.get('name', '') tool_args = tool_call.get('arguments', {}) - print(f"Tool call: {tool_name} with args: {tool_args}") - result = self.custom_call_tool(tool_name, tool_args) + print(f"Tool: {tool_name} | Args: {json.dumps(tool_args)[:100]}...") + result = self.execute_tool(tool_name, tool_args) + except json.JSONDecodeError as e: + result = f'Error: Invalid JSON in tool call. {e}' except Exception as e: - result = f'Error: Tool call is not valid JSON. Must contain "name" and "arguments" fields. Error: {e}' + result = f'Error: Tool call failed. {e}' - print(f"Tool result preview: {result[:200]}..." if len(result) > 200 else f"Tool result: {result}") + result_preview = result[:200] + "..." if len(result) > 200 else result + print(f"Result: {result_preview}") tool_response = f"\n{result}\n" messages.append({"role": "user", "content": tool_response}) @@ -241,30 +317,32 @@ def run(self, data: Dict[str, Any]) -> Dict[str, Any]: # Check for final answer if '' in content and '' in content: prediction = content.split('')[1].split('')[0] + elapsed_mins = (time.time() - start_time) / 60 + print(f"Answer found in {elapsed_mins:.1f} minutes") return { "question": question, "answer": answer, "messages": messages, - "prediction": prediction, + "prediction": prediction.strip(), "termination": "answer" } # Check token limit - token_count = self.estimate_tokens(messages) - print(f"Estimated tokens: {token_count}") + token_count = self.count_tokens(messages) + print(f"Tokens: {token_count:,}") if token_count > max_context_tokens: - print(f"Token limit exceeded: {token_count} > {max_context_tokens}") + print(f"Token limit exceeded: {token_count:,} > {max_context_tokens:,}") # Force final answer messages.append({ "role": "user", - "content": "You have reached the maximum context length. Stop making tool calls and " - "provide your best answer based on all information above in this format:\n" - "your final thinking\nyour answer" + "content": "IMPORTANT: You have reached the maximum context length. " + "Stop making tool calls. Provide your final answer NOW based on all information above.\n" + "Format: final reasoning\nyour answer" }) - content = self.generate_response(messages) + content = self.generate_response(messages, max_tokens=2048) messages.append({"role": "assistant", "content": content}) if '' in content and '' in content: @@ -272,42 +350,48 @@ def run(self, data: Dict[str, Any]) -> Dict[str, Any]: termination = "token_limit_answer" else: prediction = content - termination = "token_limit_format_error" + termination = "token_limit_no_answer" return { "question": question, "answer": answer, "messages": messages, - "prediction": prediction, + "prediction": prediction.strip(), "termination": termination } - - if num_calls_remaining <= 0: - messages.append({ - "role": "user", - "content": "Maximum LLM calls reached. Please provide your final answer now." - }) - # No answer found - last_content = messages[-1].get('content', '') - if '' in last_content: - prediction = last_content.split('')[1].split('')[0] - termination = "answer" + # Max calls reached - try to get final answer + print("Max LLM calls reached, requesting final answer...") + messages.append({ + "role": "user", + "content": "Maximum iterations reached. 
Provide your final answer NOW.\n" + "your answer" + }) + + content = self.generate_response(messages, max_tokens=2048) + messages.append({"role": "assistant", "content": content}) + + if '' in content and '' in content: + prediction = content.split('')[1].split('')[0] + termination = "max_calls_answer" else: - prediction = "No answer found." - termination = "calls_exhausted" + prediction = content if content else "No answer found." + termination = "max_calls_no_answer" return { "question": question, "answer": answer, "messages": messages, - "prediction": prediction, + "prediction": prediction.strip(), "termination": termination } def main(): - parser = argparse.ArgumentParser(description="Run DeepResearch with MLX on Apple Silicon") + parser = argparse.ArgumentParser( + description="Run DeepResearch with MLX on Apple Silicon", + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) parser.add_argument("--model", type=str, default="abalogh/Tongyi-DeepResearch-30B-A3B-4bit", help="Model path or HuggingFace model ID") parser.add_argument("--dataset", type=str, required=True, @@ -315,15 +399,21 @@ def main(): parser.add_argument("--output", type=str, default="./outputs", help="Output directory") parser.add_argument("--temperature", type=float, default=0.85, - help="Sampling temperature") + help="Sampling temperature (0.0-2.0)") parser.add_argument("--top_p", type=float, default=0.95, - help="Top-p sampling") + help="Top-p (nucleus) sampling (0.0-1.0)") parser.add_argument("--max_tokens", type=int, default=8192, help="Maximum tokens per generation") parser.add_argument("--roll_out_count", type=int, default=1, help="Number of rollouts per question") args = parser.parse_args() + # Validate args + if not 0.0 <= args.temperature <= 2.0: + print("Warning: temperature should be between 0.0 and 2.0") + if not 0.0 <= args.top_p <= 1.0: + print("Warning: top_p should be between 0.0 and 1.0") + # Setup output directory model_name = os.path.basename(args.model.rstrip('/')) model_dir = os.path.join(args.output, f"{model_name}_mlx") @@ -331,42 +421,55 @@ def main(): output_dir = os.path.join(model_dir, dataset_name) os.makedirs(output_dir, exist_ok=True) - print("=" * 50) - print("DeepResearch MLX Inference (Native API)") - print("=" * 50) - print(f"Model: {args.model}") - print(f"Dataset: {args.dataset}") - print(f"Output: {output_dir}") + print("=" * 60) + print("DeepResearch MLX Inference (Apple Silicon)") + print("=" * 60) + print(f"Model: {args.model}") + print(f"Dataset: {args.dataset}") + print(f"Output: {output_dir}") print(f"Temperature: {args.temperature}") - print(f"Rollouts: {args.roll_out_count}") - print("=" * 50) + print(f"Top-P: {args.top_p}") + print(f"Max Tokens: {args.max_tokens}") + print(f"Rollouts: {args.roll_out_count}") + print("=" * 60) # Load dataset try: if args.dataset.endswith(".json"): with open(args.dataset, "r", encoding="utf-8") as f: items = json.load(f) + if isinstance(items, dict): + items = [items] elif args.dataset.endswith(".jsonl"): with open(args.dataset, "r", encoding="utf-8") as f: - items = [json.loads(line) for line in f] + items = [json.loads(line) for line in f if line.strip()] else: - raise ValueError("Dataset must be .json or .jsonl") + print("Error: Dataset must be .json or .jsonl") + return 1 except FileNotFoundError: print(f"Error: Dataset not found at {args.dataset}") - return - except Exception as e: - print(f"Error loading dataset: {e}") - return + return 1 + except json.JSONDecodeError as e: + print(f"Error: Invalid JSON in dataset: {e}") + return 
1 print(f"Loaded {len(items)} items from dataset") + if not items: + print("Error: No items in dataset") + return 1 + # Initialize agent - agent = MLXReactAgent( - model_path=args.model, - temperature=args.temperature, - top_p=args.top_p, - max_tokens=args.max_tokens, - ) + try: + agent = MLXReactAgent( + model_path=args.model, + temperature=args.temperature, + top_p=args.top_p, + max_tokens=args.max_tokens, + ) + except Exception as e: + print(f"Error loading model: {e}") + return 1 # Setup output files per rollout output_files = { @@ -389,7 +492,8 @@ def main(): except json.JSONDecodeError: pass processed_per_rollout[rollout_idx] = processed - print(f"Rollout {rollout_idx}: {len(processed)} already processed") + if processed: + print(f"Rollout {rollout_idx}: {len(processed)} already processed") # Build task list tasks = [] @@ -415,27 +519,38 @@ def main(): if not tasks: print("All tasks already completed!") - return + return 0 # Run tasks - write_locks = {i: threading.Lock() for i in range(1, args.roll_out_count + 1)} + write_lock = threading.Lock() + completed = 0 + failed = 0 - for task in tqdm(tasks, desc="Processing"): + for task in tqdm(tasks, desc="Processing", disable=shutdown_requested): + if shutdown_requested: + print(f"\nStopped early. Completed: {completed}, Failed: {failed}") + break + rollout_idx = task["rollout_idx"] output_file = output_files[rollout_idx] try: result = agent.run(task) result["rollout_idx"] = rollout_idx + result["elapsed_time"] = time.time() - with write_locks[rollout_idx]: + with write_lock: with open(output_file, "a", encoding="utf-8") as f: f.write(json.dumps(result, ensure_ascii=False) + "\n") - + + completed += 1 + except Exception as e: - print(f"Error processing task: {e}") + failed += 1 + print(f"\nError: {e}") import traceback traceback.print_exc() + error_result = { "question": task["item"].get("question", ""), "answer": task["item"].get("answer", ""), @@ -444,13 +559,20 @@ def main(): "messages": [], "prediction": "[Failed]" } - with write_locks[rollout_idx]: + with write_lock: with open(output_file, "a", encoding="utf-8") as f: f.write(json.dumps(error_result, ensure_ascii=False) + "\n") - print("\nInference complete!") - print(f"Results saved to: {output_dir}") + print("\n" + "=" * 60) + print("Inference Complete") + print("=" * 60) + print(f"Completed: {completed}") + print(f"Failed: {failed}") + print(f"Output: {output_dir}") + print("=" * 60) + + return 0 if failed == 0 else 1 if __name__ == "__main__": - main() + sys.exit(main()) From 2dad6b66e26ac71257ab8dc4bc32ef81e4b64dff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Chindri=C8=99=20Mihai=20Alexandru?= <12643176+chindris-mihai-alexandru@users.noreply.github.com> Date: Thu, 27 Nov 2025 17:48:49 +0200 Subject: [PATCH 7/9] feat: improve agent behavior with loop detection, error handling, and visit fallbacks - Add loop detection to break infinite tool call cycles - Track consecutive errors and force answer after 3 failures - Inject reminder at round 5 to encourage timely conclusions - Rewrite visit tool with raw content fallback when summarization unavailable - Add explicit answer behavior guidelines to system prompt - Create interactive CLI (interactive.py) for normal usage - Simplify tool descriptions for clarity --- inference/interactive.py | 214 ++++++++++++++++++++ inference/prompt.py | 41 ++-- inference/run_mlx_react.py | 57 ++++++ inference/tool_visit.py | 396 +++++++++++++++++++------------------ 4 files changed, 494 insertions(+), 214 deletions(-) create mode 100644 
inference/interactive.py diff --git a/inference/interactive.py b/inference/interactive.py new file mode 100644 index 0000000..d880f1e --- /dev/null +++ b/inference/interactive.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 +""" +Interactive CLI for DeepResearch on Apple Silicon (MLX) + +Usage: + python interactive.py [--model MODEL_PATH] + +Example: + python interactive.py + python interactive.py --model abalogh/Tongyi-DeepResearch-30B-A3B-4bit +""" + +import argparse +import json +import os +import sys +import time + +# Load environment variables first +from dotenv import load_dotenv +load_dotenv(os.path.join(os.path.dirname(__file__), "..", ".env")) + +# Disable tokenizer parallelism warning +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +# Optional: rich for better formatting +try: + from rich.console import Console + from rich.markdown import Markdown + from rich.panel import Panel + from rich.progress import Progress, SpinnerColumn, TextColumn + RICH_AVAILABLE = True + console = Console() +except ImportError: + RICH_AVAILABLE = False + console = None + + +def print_header(): + """Print welcome header.""" + header = """ +╔══════════════════════════════════════════════════════════════╗ +║ DeepResearch - Interactive Mode (MLX) ║ +║ Apple Silicon Optimized ║ +╚══════════════════════════════════════════════════════════════╝ +""" + if RICH_AVAILABLE: + console.print(header, style="bold blue") + else: + print(header) + + +def print_help(): + """Print help information.""" + help_text = """ +Commands: + /help - Show this help message + /quit - Exit the program (or Ctrl+C) + /clear - Clear conversation history (start fresh) + /status - Show model and memory status + +Just type your research question to begin! + +Examples: + > What is the current population of Tokyo? + > Who won the 2024 Nobel Prize in Physics? 
+ > Explain the mechanism of CRISPR-Cas9 gene editing +""" + if RICH_AVAILABLE: + console.print(Panel(help_text, title="Help", border_style="green")) + else: + print(help_text) + + +def format_answer(answer: str): + """Format the answer for display.""" + if RICH_AVAILABLE: + console.print("\n") + console.print(Panel(Markdown(answer), title="[bold green]Answer[/]", border_style="green")) + else: + print("\n" + "=" * 60) + print("ANSWER:") + print("=" * 60) + print(answer) + print("=" * 60) + + +def main(): + parser = argparse.ArgumentParser(description="Interactive DeepResearch CLI") + parser.add_argument("--model", type=str, + default="abalogh/Tongyi-DeepResearch-30B-A3B-4bit", + help="Model path or HuggingFace ID") + parser.add_argument("--temperature", type=float, default=0.7, + help="Sampling temperature") + parser.add_argument("--max_tokens", type=int, default=4096, + help="Max tokens per generation") + parser.add_argument("--max_rounds", type=int, default=15, + help="Max research rounds per question") + args = parser.parse_args() + + print_header() + + # Set max rounds via environment + os.environ['MAX_LLM_CALL_PER_RUN'] = str(args.max_rounds) + + # Import agent after setting environment + print("Loading model (this may take a minute)...") + + try: + from run_mlx_react import MLXReactAgent, TOOL_MAP + except ImportError as e: + print(f"Error importing agent: {e}") + print("Make sure you're running from the inference directory.") + return 1 + + if RICH_AVAILABLE: + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console + ) as progress: + progress.add_task("Loading MLX model...", total=None) + agent = MLXReactAgent( + model_path=args.model, + temperature=args.temperature, + max_tokens=args.max_tokens + ) + else: + agent = MLXReactAgent( + model_path=args.model, + temperature=args.temperature, + max_tokens=args.max_tokens + ) + + print(f"\nTools available: {list(TOOL_MAP.keys())}") + print(f"Max rounds per question: {args.max_rounds}") + print_help() + + while True: + try: + # Get user input + if RICH_AVAILABLE: + query = console.input("\n[bold cyan]Research Query>[/] ").strip() + else: + query = input("\nResearch Query> ").strip() + + # Handle commands + if not query: + continue + + if query.lower() in ('/quit', '/exit', '/q'): + print("Goodbye!") + break + + if query.lower() == '/help': + print_help() + continue + + if query.lower() == '/clear': + print("Ready for a new question.") + continue + + if query.lower() == '/status': + try: + import mlx.core as mx + mem_gb = mx.metal.get_active_memory() / (1024**3) + print(f"Model: {args.model}") + print(f"GPU Memory: {mem_gb:.1f} GB") + except Exception: + print(f"Model: {args.model}") + continue + + if query.startswith('/'): + print(f"Unknown command: {query}. 
Type /help for available commands.") + continue + + # Run research + print("\nResearching...\n") + start = time.time() + + data = {'item': {'question': query, 'answer': ''}} + result = agent.run(data) + + elapsed = time.time() - start + + # Display result + prediction = result.get('prediction', 'No answer found.') + termination = result.get('termination', 'unknown') + num_rounds = len([m for m in result.get('messages', []) if m.get('role') == 'assistant']) + + format_answer(prediction) + + if RICH_AVAILABLE: + console.print(f"[dim]Completed in {elapsed:.1f}s | {num_rounds} rounds | Termination: {termination}[/]") + else: + print(f"\nCompleted in {elapsed:.1f}s | {num_rounds} rounds | Termination: {termination}") + + except KeyboardInterrupt: + print("\n\nInterrupted. Type /quit to exit or continue with a new question.") + continue + except EOFError: + print("\nGoodbye!") + break + except Exception as e: + print(f"\nError: {e}") + import traceback + traceback.print_exc() + continue + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/inference/prompt.py b/inference/prompt.py index 649e1e0..d63988b 100644 --- a/inference/prompt.py +++ b/inference/prompt.py @@ -1,4 +1,19 @@ -SYSTEM_PROMPT = """You are a deep research assistant. Your core function is to conduct thorough, multi-source investigations into any topic. You must handle both broad, open-domain inquiries and queries within specialized academic fields. For every request, synthesize information from credible, diverse sources to deliver a comprehensive, accurate, and objective response. When you have gathered sufficient information and are ready to provide the definitive response, you must enclose the entire final answer within tags. +SYSTEM_PROMPT = """You are a deep research assistant. Your core function is to conduct thorough, multi-source investigations into any topic. You must handle both broad, open-domain inquiries and queries within specialized academic fields. For every request, synthesize information from credible, diverse sources to deliver a comprehensive, accurate, and objective response. + +# CRITICAL: Answer Behavior + +**You MUST provide a final answer after gathering sufficient information.** Do not continue researching indefinitely. + +Guidelines for when to provide your answer: +1. After 2-3 search queries that return relevant results, you likely have enough information +2. If multiple sources agree on key facts, you have sufficient confirmation +3. If a webpage visit fails, use the search snippets you already have +4. A good answer with available information is better than endless searching +5. When uncertain, provide the best answer you can with appropriate caveats + +**When ready to answer, use this format:** +Final reasoning about the gathered information +Your comprehensive answer here # Tools @@ -6,25 +21,11 @@ You are provided with function signatures within XML tags: -{"type": "function", "function": {"name": "search", "description": "Perform Google web searches then returns a string of the top search results. 
Accepts multiple queries.", "parameters": {"type": "object", "properties": {"query": {"type": "array", "items": {"type": "string", "description": "The search query."}, "minItems": 1, "description": "The list of search queries."}}, "required": ["query"]}}} -{"type": "function", "function": {"name": "visit", "description": "Visit webpage(s) and return the summary of the content.", "parameters": {"type": "object", "properties": {"url": {"type": "array", "items": {"type": "string"}, "description": "The URL(s) of the webpage(s) to visit. Can be a single URL or an array of URLs."}, "goal": {"type": "string", "description": "The specific information goal for visiting webpage(s)."}}, "required": ["url", "goal"]}}} -{"type": "function", "function": {"name": "PythonInterpreter", "description": "Executes Python code in a sandboxed environment. To use this tool, you must follow this format: -1. The 'arguments' JSON object must be empty: {}. -2. The Python code to be executed must be placed immediately after the JSON block, enclosed within and tags. - -IMPORTANT: Any output you want to see MUST be printed to standard output using the print() function. - -Example of a correct call: - -{"name": "PythonInterpreter", "arguments": {}} - -import numpy as np -# Your code here -print(f"The result is: {np.mean([1,2,3])}") - -", "parameters": {"type": "object", "properties": {}, "required": []}}} -{"type": "function", "function": {"name": "google_scholar", "description": "Leverage Google Scholar to retrieve relevant information from academic publications. Accepts multiple queries. This tool will also return results from google search", "parameters": {"type": "object", "properties": {"query": {"type": "array", "items": {"type": "string", "description": "The search query."}, "minItems": 1, "description": "The list of search queries for Google Scholar."}}, "required": ["query"]}}} -{"type": "function", "function": {"name": "parse_file", "description": "This is a tool that can be used to parse multiple user uploaded local files such as PDF, DOCX, PPTX, TXT, CSV, XLSX, DOC, ZIP, MP4, MP3.", "parameters": {"type": "object", "properties": {"files": {"type": "array", "items": {"type": "string"}, "description": "The file name of the user uploaded local files to be parsed."}}, "required": ["files"]}}} +{"type": "function", "function": {"name": "search", "description": "Perform web searches and return top results with snippets. Use this first to find relevant sources.", "parameters": {"type": "object", "properties": {"query": {"type": "array", "items": {"type": "string"}, "minItems": 1, "description": "Search queries (1-3 queries recommended)."}}, "required": ["query"]}}} +{"type": "function", "function": {"name": "visit", "description": "Visit webpage(s) to extract detailed content. Only visit if search snippets are insufficient.", "parameters": {"type": "object", "properties": {"url": {"type": "array", "items": {"type": "string"}, "description": "URL(s) to visit."}, "goal": {"type": "string", "description": "What specific information you need from the page."}}, "required": ["url", "goal"]}}} +{"type": "function", "function": {"name": "google_scholar", "description": "Search academic publications. 
Use for scientific/research questions.", "parameters": {"type": "object", "properties": {"query": {"type": "array", "items": {"type": "string"}, "minItems": 1, "description": "Academic search queries."}}, "required": ["query"]}}} +{"type": "function", "function": {"name": "PythonInterpreter", "description": "Execute Python code for calculations or data processing.", "parameters": {"type": "object", "properties": {}, "required": []}}} +{"type": "function", "function": {"name": "parse_file", "description": "Parse uploaded files (PDF, DOCX, etc.).", "parameters": {"type": "object", "properties": {"files": {"type": "array", "items": {"type": "string"}, "description": "File names to parse."}}, "required": ["files"]}}} For each function call, return a json object with function name and arguments within XML tags: diff --git a/inference/run_mlx_react.py b/inference/run_mlx_react.py index a0e17e1..6cd2db7 100644 --- a/inference/run_mlx_react.py +++ b/inference/run_mlx_react.py @@ -252,6 +252,8 @@ def run(self, data: Dict[str, Any]) -> Dict[str, Any]: round_num = 0 max_context_tokens = 100 * 1024 # 100K tokens (conservative for 128K model) timeout_minutes = 120 # 2 hours + consecutive_errors = 0 + last_tool_call = "" # For loop detection while num_calls_remaining > 0: # Check for shutdown @@ -280,6 +282,13 @@ def run(self, data: Dict[str, Any]) -> Dict[str, Any]: print(f"--- Round {round_num} (calls left: {num_calls_remaining}) ---") + # Inject reminder at round 5 to encourage conclusion + if round_num == 5: + messages.append({ + "role": "user", + "content": "REMINDER: You have made several searches. If you have enough information to answer the question, please provide your final answer now using tags. Only continue searching if absolutely necessary." + }) + # Generate response content = self.generate_response(messages) @@ -292,6 +301,29 @@ def run(self, data: Dict[str, Any]) -> Dict[str, Any]: if '' in content and '' in content: tool_call_str = content.split('')[1].split('')[0] + # Loop detection: check if same tool call as last time + if tool_call_str.strip() == last_tool_call: + print("Warning: Detected repeated tool call, forcing answer...") + messages.append({ + "role": "user", + "content": "You are repeating the same action. Stop and provide your final answer NOW based on available information.\nyour answer" + }) + content = self.generate_response(messages, max_tokens=2048) + messages.append({"role": "assistant", "content": content}) + if '' in content and '' in content: + prediction = content.split('')[1].split('')[0] + else: + prediction = content + return { + "question": question, + "answer": answer, + "messages": messages, + "prediction": prediction.strip(), + "termination": "loop_detected" + } + + last_tool_call = tool_call_str.strip() + try: # Handle Python interpreter specially if "python" in tool_call_str.lower() and "" in content: @@ -308,6 +340,31 @@ def run(self, data: Dict[str, Any]) -> Dict[str, Any]: except Exception as e: result = f'Error: Tool call failed. {e}' + # Track consecutive errors + if result.startswith('Error:'): + consecutive_errors += 1 + if consecutive_errors >= 3: + print(f"Warning: {consecutive_errors} consecutive errors, forcing answer...") + messages.append({ + "role": "user", + "content": f"Multiple tool errors occurred. 
Please provide your best answer based on the information you have gathered so far.\nyour answer" + }) + content = self.generate_response(messages, max_tokens=2048) + messages.append({"role": "assistant", "content": content}) + if '' in content and '' in content: + prediction = content.split('')[1].split('')[0] + else: + prediction = content + return { + "question": question, + "answer": answer, + "messages": messages, + "prediction": prediction.strip(), + "termination": "consecutive_errors" + } + else: + consecutive_errors = 0 # Reset on success + result_preview = result[:200] + "..." if len(result) > 200 else result print(f"Result: {result_preview}") diff --git a/inference/tool_visit.py b/inference/tool_visit.py index 92e4e3a..26f6e25 100644 --- a/inference/tool_visit.py +++ b/inference/tool_visit.py @@ -1,256 +1,264 @@ import json import os -import signal -import threading +import re from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import List, Union +from typing import List, Union, Optional import requests from qwen_agent.tools.base import BaseTool, register_tool from prompt import EXTRACTOR_PROMPT from openai import OpenAI -import random -from urllib.parse import urlparse, unquote +from urllib.parse import urlparse import time -from transformers import AutoTokenizer import tiktoken VISIT_SERVER_TIMEOUT = int(os.getenv("VISIT_SERVER_TIMEOUT", 200)) WEBCONTENT_MAXLENGTH = int(os.getenv("WEBCONTENT_MAXLENGTH", 150000)) - JINA_API_KEYS = os.getenv("JINA_API_KEYS", "") +# Maximum content length to return when summarization fails +RAW_CONTENT_MAX_CHARS = int(os.getenv("RAW_CONTENT_MAX_CHARS", 8000)) + -@staticmethod def truncate_to_tokens(text: str, max_tokens: int = 95000) -> str: - encoding = tiktoken.get_encoding("cl100k_base") + """Truncate text to a maximum number of tokens.""" + try: + encoding = tiktoken.get_encoding("cl100k_base") + tokens = encoding.encode(text) + if len(tokens) <= max_tokens: + return text + return encoding.decode(tokens[:max_tokens]) + except Exception: + # Fallback: rough char estimate (4 chars per token) + max_chars = max_tokens * 4 + return text[:max_chars] if len(text) > max_chars else text + + +def extract_main_content(html_text: str, max_chars: int = 8000) -> str: + """ + Extract main content from HTML/markdown text. + Removes boilerplate like navigation, footers, etc. 
+ """ + lines = html_text.split('\n') + content_lines = [] + total_chars = 0 - tokens = encoding.encode(text) - if len(tokens) <= max_tokens: - return text + skip_patterns = [ + r'^#{1,2}\s*(navigation|menu|footer|sidebar|cookie|privacy|terms)', + r'^\s*(\||---)', # Table separators + r'^\s*\[.*\]\(.*\)\s*$', # Standalone links + r'^(copyright|©|\d{4}\s*[-–]\s*\d{4})', + ] - truncated_tokens = tokens[:max_tokens] - return encoding.decode(truncated_tokens) - -OSS_JSON_FORMAT = """# Response Formats -## visit_content -{"properties":{"rational":{"type":"string","description":"Locate the **specific sections/data** directly related to the user's goal within the webpage content"},"evidence":{"type":"string","description":"Identify and extract the **most relevant information** from the content, never miss any important information, output the **full original context** of the content as far as possible, it can be more than three paragraphs.","summary":{"type":"string","description":"Organize into a concise paragraph with logical flow, prioritizing clarity and judge the contribution of the information to the goal."}}}}""" + for line in lines: + line_lower = line.lower().strip() + + # Skip empty lines at start + if not content_lines and not line.strip(): + continue + + # Skip navigation/boilerplate patterns + skip = False + for pattern in skip_patterns: + if re.match(pattern, line_lower): + skip = True + break + + if skip: + continue + + content_lines.append(line) + total_chars += len(line) + 1 + + if total_chars >= max_chars: + break + + return '\n'.join(content_lines) @register_tool('visit', allow_overwrite=True) class Visit(BaseTool): - # The `description` tells the agent the functionality of this tool. name = 'visit' description = 'Visit webpage(s) and return the summary of the content.' - # The `parameters` tell the agent what input parameters the tool has. parameters = { "type": "object", "properties": { "url": { "type": ["string", "array"], - "items": { - "type": "string" - }, + "items": {"type": "string"}, "minItems": 1, - "description": "The URL(s) of the webpage(s) to visit. Can be a single URL or an array of URLs." - }, - "goal": { + "description": "The URL(s) of the webpage(s) to visit." + }, + "goal": { "type": "string", "description": "The goal of the visit for webpage(s)." - } + } }, "required": ["url", "goal"] } - # The `call` method is the main function of the tool. + def call(self, params: Union[str, dict], **kwargs) -> str: try: url = params["url"] goal = params["goal"] - except: - return "[Visit] Invalid request format: Input must be a JSON object containing 'url' and 'goal' fields" - - start_time = time.time() - - # Create log folder if it doesn't exist - log_folder = "log" - os.makedirs(log_folder, exist_ok=True) + except Exception: + return "[Visit] Invalid request: need 'url' and 'goal' fields" if isinstance(url, str): - response = self.readpage_jina(url, goal) - else: - response = [] - assert isinstance(url, List) - start_time = time.time() - for u in url: - if time.time() - start_time > 900: - cur_response = "The useful information in {url} for user goal {goal} as follows: \n\n".format(url=url, goal=goal) - cur_response += "Evidence in page: \n" + "The provided webpage content could not be accessed. Please check the URL or file format." + "\n\n" - cur_response += "Summary: \n" + "The webpage content could not be processed, and therefore, no information is available." 
+ "\n\n" - else: - try: - cur_response = self.readpage_jina(u, goal) - except Exception as e: - cur_response = f"Error fetching {u}: {str(e)}" - response.append(cur_response) - response = "\n=======\n".join(response) - - print(f'Summary Length {len(response)}; Summary Content {response}') - return response.strip() + return self.readpage(url, goal) - def call_server(self, msgs, max_retries=2): - api_key = os.environ.get("API_KEY") - url_llm = os.environ.get("API_BASE") - model_name = os.environ.get("SUMMARY_MODEL_NAME", "") - client = OpenAI( - api_key=api_key, - base_url=url_llm, - ) - for attempt in range(max_retries): + # Multiple URLs + responses = [] + start = time.time() + for u in url: + if time.time() - start > 300: # 5 min timeout for batch + responses.append(f"[Timeout] Skipped: {u}") + continue try: - chat_response = client.chat.completions.create( - model=model_name, - messages=msgs, - temperature=0.7 - ) - content = chat_response.choices[0].message.content - if content: - try: - json.loads(content) - except: - # extract json from string - left = content.find('{') - right = content.rfind('}') - if left != -1 and right != -1 and left <= right: - content = content[left:right+1] - return content + responses.append(self.readpage(u, goal)) except Exception as e: - # print(e) - if attempt == (max_retries - 1): - return "" - continue - - - def jina_readpage(self, url: str) -> str: - """ - Read webpage content using Jina service. + responses.append(f"[Error] {u}: {e}") - Args: - url: The URL to read - goal: The goal/purpose of reading the page - - Returns: - str: The webpage content or error message - """ - max_retries = 3 - timeout = 50 + return "\n\n---\n\n".join(responses) + + def jina_fetch(self, url: str, timeout: int = 30) -> Optional[str]: + """Fetch webpage content using Jina Reader API.""" + headers = {} + if JINA_API_KEYS: + headers["Authorization"] = f"Bearer {JINA_API_KEYS}" - for attempt in range(max_retries): - headers = { - "Authorization": f"Bearer {JINA_API_KEYS}", - } + for attempt in range(3): try: - response = requests.get( + resp = requests.get( f"https://r.jina.ai/{url}", headers=headers, timeout=timeout ) - if response.status_code == 200: - webpage_content = response.text - return webpage_content - else: - print(response.text) - raise ValueError("jina readpage error") - except Exception as e: - time.sleep(0.5) - if attempt == max_retries - 1: - return "[visit] Failed to read page." - - return "[visit] Failed to read page." + if resp.status_code == 200 and len(resp.text) > 100: + return resp.text + except requests.RequestException: + pass + time.sleep(0.5) + + return None - def html_readpage_jina(self, url: str) -> str: - max_attempts = 8 - for attempt in range(max_attempts): - content = self.jina_readpage(url) - service = "jina" - print(service) - if content and not content.startswith("[visit] Failed to read page.") and content != "[visit] Empty content." and not content.startswith("[document_parser]"): - return content - return "[visit] Failed to read page." 
+ def direct_fetch(self, url: str, timeout: int = 20) -> Optional[str]: + """Fallback: fetch directly with requests.""" + headers = { + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36" + } + try: + resp = requests.get(url, headers=headers, timeout=timeout) + resp.raise_for_status() + + # Basic HTML to text conversion + text = resp.text + # Remove script/style tags + text = re.sub(r']*>.*?', '', text, flags=re.DOTALL | re.IGNORECASE) + text = re.sub(r']*>.*?', '', text, flags=re.DOTALL | re.IGNORECASE) + # Remove HTML tags + text = re.sub(r'<[^>]+>', ' ', text) + # Clean whitespace + text = re.sub(r'\s+', ' ', text).strip() + + return text if len(text) > 100 else None + except Exception: + return None - def readpage_jina(self, url: str, goal: str) -> str: - """ - Attempt to read webpage content by alternating between jina and aidata services. + def summarize_content(self, content: str, goal: str) -> Optional[dict]: + """Use LLM API to summarize content. Returns None if unavailable.""" + api_key = os.environ.get("API_KEY") + api_base = os.environ.get("API_BASE") + model = os.environ.get("SUMMARY_MODEL_NAME", "") - Args: - url: The URL to read - goal: The goal/purpose of reading the page + if not api_key or not api_base: + return None + + try: + client = OpenAI(api_key=api_key, base_url=api_base) - Returns: - str: The webpage content or error message + # Truncate content for summarization + content = truncate_to_tokens(content, max_tokens=30000) + + prompt = EXTRACTOR_PROMPT.format(webpage_content=content, goal=goal) + + resp = client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": prompt}], + temperature=0.7, + max_tokens=2000 + ) + + result = resp.choices[0].message.content + if not result: + return None + + # Parse JSON response + result = result.replace("```json", "").replace("```", "").strip() + + # Try to extract JSON + left = result.find('{') + right = result.rfind('}') + if left != -1 and right > left: + result = result[left:right+1] + + return json.loads(result) + except Exception as e: + print(f"[visit] Summarization failed: {e}") + return None + + def readpage(self, url: str, goal: str) -> str: """ - - summary_page_func = self.call_server - max_retries = int(os.getenv('VISIT_SERVER_MAX_RETRIES', 1)) + Read and process a webpage. + + Strategy: + 1. Try Jina Reader first (best for complex pages) + 2. Fallback to direct fetch if Jina fails + 3. Try LLM summarization if API available + 4. 
Return extracted raw content if summarization unavailable + """ + # Step 1: Fetch content + content = self.jina_fetch(url) + + if not content: + content = self.direct_fetch(url) + + if not content: + return self._format_error(url, goal, "Failed to fetch webpage content") + + # Step 2: Try summarization + summary = self.summarize_content(content, goal) + + if summary and summary.get("evidence") and summary.get("summary"): + return self._format_success(url, goal, summary["evidence"], summary["summary"]) + + # Step 3: Fallback - return extracted raw content + extracted = extract_main_content(content, RAW_CONTENT_MAX_CHARS) + + if len(extracted) < 100: + return self._format_error(url, goal, "Page content too short or empty") + + return self._format_raw(url, goal, extracted) - content = self.html_readpage_jina(url) + def _format_success(self, url: str, goal: str, evidence: str, summary: str) -> str: + return f"""Content from {url} for goal: {goal} - if content and not content.startswith("[visit] Failed to read page.") and content != "[visit] Empty content." and not content.startswith("[document_parser]"): - content = truncate_to_tokens(content, max_tokens=95000) - messages = [{"role":"user","content": EXTRACTOR_PROMPT.format(webpage_content=content, goal=goal)}] - parse_retry_times = 0 - raw = summary_page_func(messages, max_retries=max_retries) - summary_retries = 3 - while len(raw) < 10 and summary_retries >= 0: - truncate_length = int(0.7 * len(content)) if summary_retries > 0 else 25000 - status_msg = ( - f"[visit] Summary url[{url}] " - f"attempt {3 - summary_retries + 1}/3, " - f"content length: {len(content)}, " - f"truncating to {truncate_length} chars" - ) if summary_retries > 0 else ( - f"[visit] Summary url[{url}] failed after 3 attempts, " - f"final truncation to 25000 chars" - ) - print(status_msg) - content = content[:truncate_length] - extraction_prompt = EXTRACTOR_PROMPT.format( - webpage_content=content, - goal=goal - ) - messages = [{"role": "user", "content": extraction_prompt}] - raw = summary_page_func(messages, max_retries=max_retries) - summary_retries -= 1 +**Evidence:** +{evidence} - parse_retry_times = 2 - if isinstance(raw, str): - raw = raw.replace("```json", "").replace("```", "").strip() - while parse_retry_times < 3: - try: - raw = json.loads(raw) - break - except: - raw = summary_page_func(messages, max_retries=max_retries) - parse_retry_times += 1 - - if parse_retry_times >= 3: - useful_information = "The useful information in {url} for user goal {goal} as follows: \n\n".format(url=url, goal=goal) - useful_information += "Evidence in page: \n" + "The provided webpage content could not be accessed. Please check the URL or file format." + "\n\n" - useful_information += "Summary: \n" + "The webpage content could not be processed, and therefore, no information is available." 
+ "\n\n" - else: - useful_information = "The useful information in {url} for user goal {goal} as follows: \n\n".format(url=url, goal=goal) - useful_information += "Evidence in page: \n" + str(raw["evidence"]) + "\n\n" - useful_information += "Summary: \n" + str(raw["summary"]) + "\n\n" +**Summary:** +{summary}""" - if len(useful_information) < 10 and summary_retries < 0: - print("[visit] Could not generate valid summary after maximum retries") - useful_information = "[visit] Failed to read page" - - return useful_information + def _format_raw(self, url: str, goal: str, content: str) -> str: + return f"""Content from {url} for goal: {goal} + +**Raw Content (summarization unavailable):** +{content} + +Note: Please extract the relevant information for your goal from the content above.""" - # If no valid content was obtained after all retries - else: - useful_information = "The useful information in {url} for user goal {goal} as follows: \n\n".format(url=url, goal=goal) - useful_information += "Evidence in page: \n" + "The provided webpage content could not be accessed. Please check the URL or file format." + "\n\n" - useful_information += "Summary: \n" + "The webpage content could not be processed, and therefore, no information is available." + "\n\n" - return useful_information + def _format_error(self, url: str, goal: str, reason: str) -> str: + return f"""Could not retrieve content from {url} +Goal: {goal} +Reason: {reason} - \ No newline at end of file +Please try a different source or search query.""" From fca16c2fe5098d4802a99c42377fb9c90ae1de28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Chindri=C8=99=20Mihai=20Alexandru?= <12643176+chindris-mihai-alexandru@users.noreply.github.com> Date: Thu, 27 Nov 2025 17:54:06 +0200 Subject: [PATCH 8/9] fix: use new MLX memory API to avoid deprecation warnings --- inference/interactive.py | 6 +++++- inference/run_mlx_react.py | 5 +++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/inference/interactive.py b/inference/interactive.py index d880f1e..38dd560 100644 --- a/inference/interactive.py +++ b/inference/interactive.py @@ -163,7 +163,11 @@ def main(): if query.lower() == '/status': try: import mlx.core as mx - mem_gb = mx.metal.get_active_memory() / (1024**3) + # Use new API (mlx >= 0.24) or fall back to deprecated + if hasattr(mx, 'get_active_memory'): + mem_gb = mx.get_active_memory() / (1024**3) + else: + mem_gb = mx.metal.get_active_memory() / (1024**3) print(f"Model: {args.model}") print(f"GPU Memory: {mem_gb:.1f} GB") except Exception: diff --git a/inference/run_mlx_react.py b/inference/run_mlx_react.py index 6cd2db7..9b87323 100644 --- a/inference/run_mlx_react.py +++ b/inference/run_mlx_react.py @@ -116,8 +116,9 @@ def _get_memory_usage(self) -> float: """Get current GPU memory usage in GB.""" try: import mlx.core as mx - # Force memory stats update - mx.metal.get_peak_memory() + # Use new API (mlx >= 0.24) or fall back to deprecated + if hasattr(mx, 'get_active_memory'): + return mx.get_active_memory() / (1024**3) return mx.metal.get_active_memory() / (1024**3) except Exception: return 0.0 From bff6b816789f11f38df92212c8cf664ad3d1467c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Chindri=C8=99=20Mihai=20Alexandru?= <12643176+chindris-mihai-alexandru@users.noreply.github.com> Date: Thu, 27 Nov 2025 17:56:20 +0200 Subject: [PATCH 9/9] fix: prevent tool_args mutation and add URL validation - Copy tool_args before passing to tools to prevent mutation - Remove unused imports (ThreadPoolExecutor, as_completed) - Add URL validation to 
reject invalid URLs early - Verified mlx-lm generate() API usage is correct --- inference/run_mlx_react.py | 10 +++++----- inference/tool_visit.py | 14 +++++++++++++- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/inference/run_mlx_react.py b/inference/run_mlx_react.py index 9b87323..c5bbbac 100644 --- a/inference/run_mlx_react.py +++ b/inference/run_mlx_react.py @@ -184,8 +184,8 @@ def execute_tool(self, tool_name: str, tool_args: Dict[str, Any], timeout: int = if tool_name not in TOOL_MAP: return f"Error: Tool '{tool_name}' not found. Available: {list(TOOL_MAP.keys())}" - # Prepare args - tool_args["params"] = tool_args + # Copy args to avoid mutation + args = dict(tool_args) result = "" error = None @@ -193,14 +193,14 @@ def run_tool(): nonlocal result, error try: if "python" in tool_name.lower(): - result = str(TOOL_MAP['PythonInterpreter'].call(tool_args)) + result = str(TOOL_MAP['PythonInterpreter'].call(args)) elif tool_name == "parse_file": import asyncio - params = {"files": tool_args.get("files", [])} + params = {"files": args.get("files", [])} r = asyncio.run(TOOL_MAP[tool_name].call(params, file_root_path="./eval_data/file_corpus")) result = str(r) if not isinstance(r, str) else r else: - result = str(TOOL_MAP[tool_name].call(tool_args)) + result = str(TOOL_MAP[tool_name].call(args)) except Exception as e: error = str(e) diff --git a/inference/tool_visit.py b/inference/tool_visit.py index 26f6e25..7892244 100644 --- a/inference/tool_visit.py +++ b/inference/tool_visit.py @@ -1,7 +1,6 @@ import json import os import re -from concurrent.futures import ThreadPoolExecutor, as_completed from typing import List, Union, Optional import requests from qwen_agent.tools.base import BaseTool, register_tool @@ -96,6 +95,14 @@ class Visit(BaseTool): "required": ["url", "goal"] } + def _validate_url(self, url: str) -> bool: + """Check if URL is valid and has a proper scheme.""" + try: + parsed = urlparse(url) + return parsed.scheme in ('http', 'https') and bool(parsed.netloc) + except Exception: + return False + def call(self, params: Union[str, dict], **kwargs) -> str: try: url = params["url"] @@ -104,6 +111,8 @@ def call(self, params: Union[str, dict], **kwargs) -> str: return "[Visit] Invalid request: need 'url' and 'goal' fields" if isinstance(url, str): + if not self._validate_url(url): + return f"[Visit] Invalid URL: {url}. URL must start with http:// or https://" return self.readpage(url, goal) # Multiple URLs @@ -113,6 +122,9 @@ def call(self, params: Union[str, dict], **kwargs) -> str: if time.time() - start > 300: # 5 min timeout for batch responses.append(f"[Timeout] Skipped: {u}") continue + if not self._validate_url(u): + responses.append(f"[Visit] Invalid URL: {u}") + continue try: responses.append(self.readpage(u, goal)) except Exception as e: