From f4e0321072bf70eca3d7c6f4418cfb21c6f97100 Mon Sep 17 00:00:00 2001 From: blocks Date: Tue, 24 Mar 2026 17:53:17 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20Grok=20=E4=B8=AD?= =?UTF-8?q?=E8=BD=AC=E7=AB=99=E7=A9=BA=E7=BB=93=E6=9E=9C=E5=B9=B6=E5=A2=9E?= =?UTF-8?q?=E5=BC=BA=E8=A7=84=E5=88=92=E4=B8=8E=E6=8A=93=E5=8F=96=E5=AE=B9?= =?UTF-8?q?=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/grok_search/config.py | 8 ++ src/grok_search/logger.py | 1 - src/grok_search/planning.py | 39 ++++++ src/grok_search/providers/grok.py | 179 +++++++++++++++++++----- src/grok_search/server.py | 221 +++++++++++++++++++++++++----- src/grok_search/sources.py | 84 ++++++++++++ 6 files changed, 465 insertions(+), 67 deletions(-) diff --git a/src/grok_search/config.py b/src/grok_search/config.py index bdfbfd6..ea88507 100644 --- a/src/grok_search/config.py +++ b/src/grok_search/config.py @@ -63,6 +63,13 @@ def retry_multiplier(self) -> float: def retry_max_wait(self) -> int: return int(os.getenv("GROK_RETRY_MAX_WAIT", "10")) + @property + def output_cleanup_enabled(self) -> bool: + raw = os.getenv("GROK_OUTPUT_CLEANUP") + if raw is None: + raw = os.getenv("GROK_FILTER_THINK_TAGS", "true") + return raw.lower() in ("true", "1", "yes") + @property def grok_api_url(self) -> str: url = os.getenv("GROK_API_URL") @@ -184,6 +191,7 @@ def get_config_info(self) -> dict: "GROK_API_KEY": api_key_masked, "GROK_MODEL": self.grok_model, "GROK_DEBUG": self.debug_enabled, + "GROK_OUTPUT_CLEANUP": self.output_cleanup_enabled, "GROK_LOG_LEVEL": self.log_level, "GROK_LOG_DIR": str(self.log_dir), "TAVILY_API_URL": self.tavily_api_url, diff --git a/src/grok_search/logger.py b/src/grok_search/logger.py index 57f711d..4b62d0a 100644 --- a/src/grok_search/logger.py +++ b/src/grok_search/logger.py @@ -1,6 +1,5 @@ import logging from datetime import datetime -from pathlib import Path from .config import config logger = logging.getLogger("grok_search") diff --git a/src/grok_search/planning.py b/src/grok_search/planning.py index 9f67a73..3fc76c2 100644 --- a/src/grok_search/planning.py +++ b/src/grok_search/planning.py @@ -84,6 +84,13 @@ class ExecutionOrderOutput(BaseModel): _ACCUMULATIVE_LIST_PHASES = {"query_decomposition", "tool_selection"} _MERGE_STRATEGY_PHASE = "search_strategy" +_PHASE_PREDECESSORS = { + "complexity_assessment": "intent_analysis", + "query_decomposition": "complexity_assessment", + "search_strategy": "query_decomposition", + "tool_selection": "search_strategy", + "execution_order": "tool_selection", +} def _split_csv(value: str) -> list[str]: @@ -126,6 +133,9 @@ def __init__(self): def get_session(self, session_id: str) -> PlanningSession | None: return self._sessions.get(session_id) + def reset(self) -> None: + self._sessions.clear() + def process_phase( self, phase: str, @@ -147,6 +157,35 @@ def process_phase( if target not in PHASE_NAMES: return {"error": f"Unknown phase: {target}. Valid: {', '.join(PHASE_NAMES)}"} + if not is_revision: + predecessor = _PHASE_PREDECESSORS.get(target) + if predecessor and predecessor not in session.phases: + return { + "error": f"Phase '{target}' requires '{predecessor}' to be completed first.", + "expected_phase_order": PHASE_NAMES, + "session_id": session.session_id, + "completed_phases": session.completed_phases, + "complexity_level": session.complexity_level, + } + + if session.complexity_level == 1 and target in {"search_strategy", "tool_selection", "execution_order"}: + return { + "error": "Level 1 planning completes after query_decomposition.", + "expected_phase_order": PHASE_NAMES, + "session_id": session.session_id, + "completed_phases": session.completed_phases, + "complexity_level": session.complexity_level, + } + + if session.complexity_level == 2 and target == "execution_order": + return { + "error": "Level 2 planning completes after tool_selection.", + "expected_phase_order": PHASE_NAMES, + "session_id": session.session_id, + "completed_phases": session.completed_phases, + "complexity_level": session.complexity_level, + } + if target in _ACCUMULATIVE_LIST_PHASES: if is_revision: session.phases[target] = PhaseRecord( diff --git a/src/grok_search/providers/grok.py b/src/grok_search/providers/grok.py index bd2820a..c9af496 100644 --- a/src/grok_search/providers/grok.py +++ b/src/grok_search/providers/grok.py @@ -5,7 +5,6 @@ from typing import List, Optional from tenacity import AsyncRetrying, retry_if_exception, stop_after_attempt, wait_random_exponential from tenacity.wait import wait_base -from zoneinfo import ZoneInfo from .base import BaseSearchProvider, SearchResult from ..utils import search_prompt, fetch_prompt, url_describe_prompt, rank_sources_prompt from ..logger import log_info @@ -125,11 +124,16 @@ def __init__(self, api_url: str, api_key: str, model: str = "grok-4-fast"): def get_provider_name(self) -> str: return "Grok" - async def search(self, query: str, platform: str = "", min_results: int = 3, max_results: int = 10, ctx=None) -> List[SearchResult]: - headers = { + def _build_api_headers(self) -> dict: + return { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + "User-Agent": "grok-search-mcp/0.1.0", } + + async def search(self, query: str, platform: str = "", min_results: int = 3, max_results: int = 10, ctx=None) -> List[SearchResult]: + headers = self._build_api_headers() platform_prompt = "" if platform: @@ -146,18 +150,15 @@ async def search(self, query: str, platform: str = "", min_results: int = 3, max }, {"role": "user", "content": time_context + query + platform_prompt}, ], - "stream": True, + "stream": False, } await log_info(ctx, f"platform_prompt: { query + platform_prompt}", config.debug_enabled) - return await self._execute_stream_with_retry(headers, payload, ctx) + return await self._execute_completion_with_retry(headers, payload, ctx) async def fetch(self, url: str, ctx=None) -> str: - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - } + headers = self._build_api_headers() payload = { "model": self.model, "messages": [ @@ -167,19 +168,67 @@ async def fetch(self, url: str, ctx=None) -> str: }, {"role": "user", "content": url + "\n获取该网页内容并返回其结构化Markdown格式" }, ], - "stream": True, + "stream": False, } - return await self._execute_stream_with_retry(headers, payload, ctx) + return await self._execute_completion_with_retry(headers, payload, ctx) + + def _extract_content_from_choice(self, choice: dict) -> str: + if not isinstance(choice, dict): + return "" + + message = choice.get("message", {}) + if isinstance(message, dict): + content = message.get("content", "") + if isinstance(content, str) and content: + return content + + delta = choice.get("delta", {}) + if isinstance(delta, dict): + content = delta.get("content", "") + if isinstance(content, str) and content: + return content + + for key in ("text", "content"): + value = choice.get(key, "") + if isinstance(value, str) and value: + return value + + return "" + + def _is_empty_placeholder_payload(self, data: dict) -> bool: + if not isinstance(data, dict): + return False + + if data.get("choices", object()) is not None: + return False + + return all(not str(data.get(key, "")).strip() for key in ("id", "object", "model")) + + def _build_placeholder_error(self, headers=None) -> ValueError: + request_id = "" + if headers: + request_id = ( + headers.get("x-oneapi-request-id", "") + or headers.get("x-request-id", "") + or headers.get("request-id", "") + ).strip() + + message = "上游返回了空的占位 completion 帧(choices=null),疑似中转站对 Grok chat/completions 的实现异常" + if request_id: + message += f",request_id={request_id}" + return ValueError(message) async def _parse_streaming_response(self, response, ctx=None) -> str: content = "" - full_body_buffer = [] - + full_body_buffer = [] + empty_placeholder_detected = False + response_headers = getattr(response, "headers", None) + async for line in response.aiter_lines(): line = line.strip() if not line: continue - + full_body_buffer.append(line) # 兼容 "data: {...}" 和 "data:{...}" 两种 SSE 格式 @@ -190,24 +239,70 @@ async def _parse_streaming_response(self, response, ctx=None) -> str: # 去掉 "data:" 前缀,并去除可能的空格 json_str = line[5:].lstrip() data = json.loads(json_str) + if self._is_empty_placeholder_payload(data): + empty_placeholder_detected = True + continue choices = data.get("choices", []) - if choices and len(choices) > 0: - delta = choices[0].get("delta", {}) - if "content" in delta: - content += delta["content"] + if isinstance(choices, list) and choices: + chunk = self._extract_content_from_choice(choices[0]) + if chunk: + content += chunk except (json.JSONDecodeError, IndexError): continue - + if not content and full_body_buffer: try: full_text = "".join(full_body_buffer) data = json.loads(full_text) - if "choices" in data and len(data["choices"]) > 0: - message = data["choices"][0].get("message", {}) - content = message.get("content", "") + if self._is_empty_placeholder_payload(data): + empty_placeholder_detected = True + choices = data.get("choices", []) + if isinstance(choices, list) and choices: + content = self._extract_content_from_choice(choices[0]) except json.JSONDecodeError: pass - + + if not content and empty_placeholder_detected: + raise self._build_placeholder_error(response_headers) + + await log_info(ctx, f"content: {content}", config.debug_enabled) + + return content + + async def _parse_completion_response(self, response: httpx.Response, ctx=None) -> str: + content = "" + body_text = response.text or "" + + try: + data = response.json() + except Exception: + data = None + + if isinstance(data, dict): + if self._is_empty_placeholder_payload(data): + raise self._build_placeholder_error(response.headers) + choices = data.get("choices", []) + if isinstance(choices, list) and choices: + content = self._extract_content_from_choice(choices[0]) + + if not content and any(line.lstrip().startswith("data:") for line in body_text.splitlines()): + class _LineResponse: + def __init__(self, text: str, headers): + self._lines = text.splitlines() + self.headers = headers + + async def aiter_lines(self): + for line in self._lines: + yield line + + content = await self._parse_streaming_response(_LineResponse(body_text, response.headers), ctx) + + if not content and body_text.strip(): + normalized = body_text.lower() + if " str: + """执行带重试机制的非流式 HTTP 请求,兼容 JSON completion 与 SSE 文本响应。""" + timeout = httpx.Timeout(connect=6.0, read=120.0, write=10.0, pool=None) + + async with httpx.AsyncClient(timeout=timeout, follow_redirects=True) as client: + async for attempt in AsyncRetrying( + stop=stop_after_attempt(config.retry_max_attempts + 1), + wait=_WaitWithRetryAfter(config.retry_multiplier, config.retry_max_wait), + retry=retry_if_exception(_is_retryable_exception), + reraise=True, + ): + with attempt: + response = await client.post( + f"{self.api_url}/chat/completions", + headers=headers, + json=payload, + ) + response.raise_for_status() + return await self._parse_completion_response(response, ctx) + async def describe_url(self, url: str, ctx=None) -> dict: """让 Grok 阅读单个 URL 并返回 title + extracts""" - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - } + headers = self._build_api_headers() payload = { "model": self.model, "messages": [ {"role": "system", "content": url_describe_prompt}, {"role": "user", "content": url}, ], - "stream": True, + "stream": False, } - result = await self._execute_stream_with_retry(headers, payload, ctx) + result = await self._execute_completion_with_retry(headers, payload, ctx) title, extracts = url, "" for line in result.strip().splitlines(): if line.startswith("Title:"): @@ -258,19 +370,16 @@ async def describe_url(self, url: str, ctx=None) -> dict: async def rank_sources(self, query: str, sources_text: str, total: int, ctx=None) -> list[int]: """让 Grok 按查询相关度对信源排序,返回排序后的序号列表""" - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - } + headers = self._build_api_headers() payload = { "model": self.model, "messages": [ {"role": "system", "content": rank_sources_prompt}, {"role": "user", "content": f"Query: {query}\n\n{sources_text}"}, ], - "stream": True, + "stream": False, } - result = await self._execute_stream_with_retry(headers, payload, ctx) + result = await self._execute_completion_with_retry(headers, payload, ctx) order: list[int] = [] seen: set[int] = set() for token in result.strip().split(): diff --git a/src/grok_search/server.py b/src/grok_search/server.py index 7754216..0f3c41f 100644 --- a/src/grok_search/server.py +++ b/src/grok_search/server.py @@ -1,15 +1,16 @@ +import asyncio import sys from pathlib import Path +from typing import Annotated, Optional + +from fastmcp import FastMCP, Context +from pydantic import Field # 支持直接运行:添加 src 目录到 Python 路径 src_dir = Path(__file__).parent.parent if str(src_dir) not in sys.path: sys.path.insert(0, str(src_dir)) -from fastmcp import FastMCP, Context -from typing import Annotated, Optional -from pydantic import Field - # 尝试使用绝对导入(支持 mcp run) try: from grok_search.providers.grok import GrokSearchProvider @@ -24,8 +25,6 @@ from .sources import SourcesCache, merge_sources, new_session_id, split_answer_and_sources from .planning import engine as planning_engine, _split_csv -import asyncio - mcp = FastMCP("grok-search") _SOURCES_CACHE = SourcesCache(max_size=256) @@ -71,6 +70,132 @@ async def _get_available_models_cached(api_url: str, api_key: str) -> list[str]: return models +def _planning_session_error(session_id: str) -> str: + import json + + return json.dumps( + { + "error": "session_not_found", + "message": f"Session '{session_id}' not found. Call plan_intent first.", + "expected_phase_order": [ + "intent_analysis", + "complexity_assessment", + "query_decomposition", + "search_strategy", + "tool_selection", + "execution_order", + ], + "restart_from_intent_analysis": True, + }, + ensure_ascii=False, + indent=2, + ) + + +def _extract_request_id(headers) -> str: + if not headers: + return "" + + return ( + headers.get("x-oneapi-request-id", "") + or headers.get("x-request-id", "") + or headers.get("request-id", "") + ).strip() + + +def _extract_error_summary(response) -> str: + if response is None: + return "" + + try: + data = response.json() + except Exception: + data = None + + if isinstance(data, dict): + error = data.get("error", {}) + if isinstance(error, dict): + message = (error.get("message") or "").strip() + if message: + return message + + body_text = (getattr(response, "text", "") or "").strip() + if not body_text: + return "" + + normalized = body_text.lower() + if " str: + import httpx + + if isinstance(exc, httpx.TimeoutException): + return "搜索失败: 上游请求超时,请稍后重试" + + if isinstance(exc, httpx.HTTPStatusError): + status_code = exc.response.status_code + location = exc.response.headers.get("location", "").strip() + request_id = _extract_request_id(exc.response.headers) + summary = _extract_error_summary(exc.response) + if status_code in {301, 302, 303, 307, 308} and location: + message = f"搜索失败: 上游返回 HTTP {status_code} 重定向到 {location},请检查代理认证状态" + else: + message = f"搜索失败: 上游返回 HTTP {status_code}" + if summary: + message += f",摘要={summary}" + if request_id: + message += f",request_id={request_id}" + return message + + message = str(exc).strip() + if message: + return f"搜索失败: {message}" + return "搜索失败: 上游请求异常" + + +def _looks_like_login_page(body_text: str) -> bool: + normalized = (body_text or "").strip().lower() + if " str: + import httpx + + if isinstance(exc, httpx.TimeoutException): + return f"{provider} 请求超时" + + if isinstance(exc, httpx.HTTPStatusError): + status_code = exc.response.status_code + location = exc.response.headers.get("location", "").strip() + request_id = _extract_request_id(exc.response.headers) + summary = _extract_error_summary(exc.response) + if status_code in {301, 302, 303, 307, 308} and location: + message = f"{provider} 返回 HTTP {status_code} 重定向到 {location},请检查认证状态" + elif status_code in {401, 403}: + message = f"{provider} 返回 HTTP {status_code},请检查认证状态" + else: + message = f"{provider} 返回 HTTP {status_code}" + if summary: + message += f",摘要={summary}" + if request_id: + message += f",request_id={request_id}" + return message + + message = str(exc).strip() + if message: + return f"{provider} 请求失败: {message}" + return f"{provider} 请求失败" + + def _extra_results_to_sources( tavily_results: list[dict] | None, firecrawl_results: list[dict] | None, @@ -164,11 +289,14 @@ async def web_search( tavily_count = extra_sources # 并行执行搜索任务 - async def _safe_grok() -> str: + async def _safe_grok() -> tuple[str, str | None]: try: - return await grok_provider.search(query, platform) - except Exception: - return "" + result = await grok_provider.search(query, platform) + except Exception as exc: + return "", _format_grok_error(exc) + if not result or not result.strip(): + return "", "搜索失败: 上游返回空响应,请检查模型或代理配置" + return result, None async def _safe_tavily() -> list[dict] | None: try: @@ -192,7 +320,7 @@ async def _safe_firecrawl() -> list[dict] | None: gathered = await asyncio.gather(*coros) - grok_result: str = gathered[0] or "" + grok_result, grok_error = gathered[0] tavily_results: list[dict] | None = None firecrawl_results: list[dict] | None = None idx = 1 @@ -205,9 +333,10 @@ async def _safe_firecrawl() -> list[dict] | None: answer, grok_sources = split_answer_and_sources(grok_result) extra = _extra_results_to_sources(tavily_results, firecrawl_results) all_sources = merge_sources(grok_sources, extra) + content = answer if answer.strip() else (grok_error or "") await _SOURCES_CACHE.set(session_id, all_sources) - return {"session_id": session_id, "content": answer, "sources_count": len(all_sources)} + return {"session_id": session_id, "content": content, "sources_count": len(all_sources)} @mcp.tool( @@ -233,12 +362,12 @@ async def get_sources( return {"session_id": session_id, "sources": sources, "sources_count": len(sources)} -async def _call_tavily_extract(url: str) -> str | None: +async def _call_tavily_extract(url: str) -> tuple[str | None, str | None]: import httpx api_url = config.tavily_api_url api_key = config.tavily_api_key if not api_key: - return None + return None, None endpoint = f"{api_url.rstrip('/')}/extract" headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"} body = {"urls": [url], "format": "markdown"} @@ -246,13 +375,17 @@ async def _call_tavily_extract(url: str) -> str | None: async with httpx.AsyncClient(timeout=60.0) as client: response = await client.post(endpoint, headers=headers, json=body) response.raise_for_status() + if _looks_like_login_page(response.text): + return None, "Tavily 返回登录页或认证页面,请检查代理认证状态" data = response.json() if data.get("results") and len(data["results"]) > 0: content = data["results"][0].get("raw_content", "") - return content if content and content.strip() else None - return None - except Exception: - return None + if content and content.strip(): + return content, None + return None, "Tavily 提取成功但内容为空" + return None, "Tavily 提取成功但 results 为空" + except Exception as exc: + return None, _format_fetch_error("Tavily", exc) async def _call_tavily_search(query: str, max_results: int = 6) -> list[dict] | None: @@ -305,15 +438,16 @@ async def _call_firecrawl_search(query: str, limit: int = 14) -> list[dict] | No return None -async def _call_firecrawl_scrape(url: str, ctx=None) -> str | None: +async def _call_firecrawl_scrape(url: str, ctx=None) -> tuple[str | None, str | None]: import httpx api_url = config.firecrawl_api_url api_key = config.firecrawl_api_key if not api_key: - return None + return None, None endpoint = f"{api_url.rstrip('/')}/scrape" headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"} max_retries = config.retry_max_attempts + last_error: str | None = None for attempt in range(max_retries): body = { "url": url, @@ -325,15 +459,19 @@ async def _call_firecrawl_scrape(url: str, ctx=None) -> str | None: async with httpx.AsyncClient(timeout=90.0) as client: response = await client.post(endpoint, headers=headers, json=body) response.raise_for_status() + if _looks_like_login_page(response.text): + return None, "Firecrawl 返回登录页或认证页面,请检查代理认证状态" data = response.json() markdown = data.get("data", {}).get("markdown", "") if markdown and markdown.strip(): - return markdown + return markdown, None + last_error = "Firecrawl 返回空 markdown" await log_info(ctx, f"Firecrawl: markdown为空, 重试 {attempt + 1}/{max_retries}", config.debug_enabled) except Exception as e: + last_error = _format_fetch_error("Firecrawl", e) await log_info(ctx, f"Firecrawl error: {e}", config.debug_enabled) - return None - return None + return None, last_error + return None, last_error @mcp.tool( @@ -360,20 +498,28 @@ async def web_fetch( ) -> str: await log_info(ctx, f"Begin Fetch: {url}", config.debug_enabled) - result = await _call_tavily_extract(url) + result, tavily_error = await _call_tavily_extract(url) if result: await log_info(ctx, "Fetch Finished (Tavily)!", config.debug_enabled) return result + if tavily_error: + await log_info(ctx, f"Tavily extract failed: {tavily_error}", config.debug_enabled) await log_info(ctx, "Tavily unavailable or failed, trying Firecrawl...", config.debug_enabled) - result = await _call_firecrawl_scrape(url, ctx) + result, firecrawl_error = await _call_firecrawl_scrape(url, ctx) if result: await log_info(ctx, "Fetch Finished (Firecrawl)!", config.debug_enabled) return result + if firecrawl_error: + await log_info(ctx, f"Firecrawl scrape failed: {firecrawl_error}", config.debug_enabled) await log_info(ctx, "Fetch Failed!", config.debug_enabled) if not config.tavily_api_key and not config.firecrawl_api_key: return "配置错误: TAVILY_API_KEY 和 FIRECRAWL_API_KEY 均未配置" + + errors = [error for error in (tavily_error, firecrawl_error) if error] + if errors: + return f"提取失败: {';'.join(errors)}" return "提取失败: 所有提取服务均未能获取内容" @@ -510,7 +656,7 @@ async def get_config_info() -> str: if model_names: test_result["available_models"] = model_names - except: + except Exception: pass else: test_result["status"] = "⚠️ 连接异常" @@ -713,7 +859,7 @@ async def plan_complexity( ) -> str: import json if not planning_engine.get_session(session_id): - return json.dumps({"error": f"Session '{session_id}' not found. Call plan_intent first."}) + return _planning_session_error(session_id) return json.dumps(planning_engine.process_phase( phase="complexity_assessment", thought=thought, session_id=session_id, is_revision=is_revision, confidence=confidence, @@ -741,7 +887,7 @@ async def plan_sub_query( ) -> str: import json if not planning_engine.get_session(session_id): - return json.dumps({"error": f"Session '{session_id}' not found. Call plan_intent first."}) + return _planning_session_error(session_id) item = {"id": id, "goal": goal, "expected_output": expected_output, "boundary": boundary} if depends_on: item["depends_on"] = _split_csv(depends_on) @@ -771,7 +917,7 @@ async def plan_search_term( ) -> str: import json if not planning_engine.get_session(session_id): - return json.dumps({"error": f"Session '{session_id}' not found. Call plan_intent first."}) + return _planning_session_error(session_id) data = {"search_terms": [{"term": term, "purpose": purpose, "round": round}]} if approach: data["approach"] = approach @@ -800,7 +946,7 @@ async def plan_tool_mapping( ) -> str: import json if not planning_engine.get_session(session_id): - return json.dumps({"error": f"Session '{session_id}' not found. Call plan_intent first."}) + return _planning_session_error(session_id) item = {"sub_query_id": sub_query_id, "tool": tool, "reason": reason} if params_json: try: @@ -829,7 +975,7 @@ async def plan_execution( ) -> str: import json if not planning_engine.get_session(session_id): - return json.dumps({"error": f"Session '{session_id}' not found. Call plan_intent first."}) + return _planning_session_error(session_id) parallel = [_split_csv(g) for g in parallel_groups.split(";") if g.strip()] if parallel_groups else [] seq = _split_csv(sequential) return json.dumps(planning_engine.process_phase( @@ -839,6 +985,17 @@ async def plan_execution( ), ensure_ascii=False, indent=2) +def _configure_windows_event_loop_policy() -> None: + if sys.platform != "win32": + return + + policy_cls = getattr(asyncio, "WindowsSelectorEventLoopPolicy", None) + if policy_cls is None: + return + + asyncio.set_event_loop_policy(policy_cls()) + + def main(): import signal import os @@ -852,6 +1009,8 @@ def handle_shutdown(signum, frame): if sys.platform != 'win32': signal.signal(signal.SIGTERM, handle_shutdown) + _configure_windows_event_loop_policy() + # Windows 父进程监控 if sys.platform == 'win32': import time diff --git a/src/grok_search/sources.py b/src/grok_search/sources.py index 63386e2..304f7d7 100644 --- a/src/grok_search/sources.py +++ b/src/grok_search/sources.py @@ -7,6 +7,7 @@ import asyncio +from .config import config from .utils import extract_unique_urls @@ -23,6 +24,47 @@ _SOURCES_FUNCTION_PATTERN = re.compile( r"(?im)(^|\n)\s*(sources|source|citations|citation|references|reference|citation_card|source_cards|source_card)\s*\(" ) +_THINK_BLOCK_PATTERN = re.compile(r"(?is).*?") +_LEADING_POLICY_PATTERNS = [ + re.compile(r"(?is)^\s*\**\s*i cannot comply\b.*"), + re.compile(r"(?is)^\s*\**\s*i do not accept\b.*"), + re.compile(r"(?is)^\s*\**\s*i do not follow\b.*"), + re.compile(r"(?is)^\s*\**\s*i don't follow\b.*"), + re.compile(r"(?is)^\s*\**\s*i don't adopt\b.*"), + re.compile(r"(?is)^\s*\**\s*refusal\s*[::].*"), + re.compile(r"(?is)^\s*\**\s*refuse to\b.*"), + re.compile(r"(?is)^\s*\**\s*rejected?\b.*"), + re.compile(r"(?is)^\s*\**\s*拒绝执行\b.*"), + re.compile(r"(?is)^\s*\**\s*无法遵循\b.*"), +] +_POLICY_META_KEYWORDS = ( + "cannot comply", + "refuse", + "refusal", + "do not follow", + "don't follow", + "don't adopt", + "override my core", + "core behavior", + "custom rules", + "用户提供的自定义", + "覆盖我的核心", + "核心行为", + "拒绝执行", + "无法遵循", +) +_POLICY_CONTEXT_KEYWORDS = ( + "jailbreak", + "prompt injection", + "system instructions", + "system prompt", + "user-injected", + "注入", + "越狱", + "系统指令", + "系统提示", + "自定义“system”", +) def new_session_id() -> str: @@ -72,6 +114,11 @@ def split_answer_and_sources(text: str) -> tuple[str, list[dict]]: if not raw: return "", [] + if config.output_cleanup_enabled: + cleaned = sanitize_answer_text(raw) + if cleaned: + raw = cleaned + split = _split_function_call_sources(raw) if split: return split @@ -91,6 +138,43 @@ def split_answer_and_sources(text: str) -> tuple[str, list[dict]]: return raw, [] +def sanitize_answer_text(text: str) -> str: + raw = (text or "").strip() + if not raw: + return "" + + cleaned = _THINK_BLOCK_PATTERN.sub("", raw).strip() + paragraphs = _split_paragraphs(cleaned) + filtered = [paragraph for paragraph in paragraphs if not _looks_like_policy_block(paragraph)] + if filtered: + return "\n\n".join(filtered).strip() + return cleaned + + +def _split_paragraphs(text: str) -> list[str]: + parts = [part.strip() for part in re.split(r"\n\s*\n", text) if part.strip()] + return parts or ([text.strip()] if text.strip() else []) + + +def _looks_like_policy_block(text: str) -> bool: + normalized = _normalize_policy_text(text) + if not normalized: + return False + + if any(pattern.match(normalized) for pattern in _LEADING_POLICY_PATTERNS): + return True + + return any(keyword in normalized for keyword in _POLICY_META_KEYWORDS) and any( + keyword in normalized for keyword in _POLICY_CONTEXT_KEYWORDS + ) + + +def _normalize_policy_text(text: str) -> str: + normalized = re.sub(r"[>*_`#-]+", " ", text or "") + normalized = re.sub(r"\s+", " ", normalized) + return normalized.strip().lower() + + def _split_function_call_sources(text: str) -> tuple[str, list[dict]] | None: matches = list(_SOURCES_FUNCTION_PATTERN.finditer(text)) if not matches: