diff --git a/.env.example b/.env.example index 5a80abea..48e82544 100644 --- a/.env.example +++ b/.env.example @@ -127,6 +127,9 @@ REFRAG_SENSE=heuristic LLAMACPP_URL=http://llamacpp:8080 REFRAG_DECODER_MODE=prompt # prompt|soft +# GLM_API_BASE=https://api.z.ai/api/coding/paas/v4/ +# GLM_MODEL=glm-4.6 + # GPU Performance Toggle # Set to 1 to use native GPU-accelerated server on localhost:8081 # Set to 0 to use Docker CPU-only server (default, stable) diff --git a/Dockerfile.indexer b/Dockerfile.indexer index 994345fc..3d35a027 100644 --- a/Dockerfile.indexer +++ b/Dockerfile.indexer @@ -8,7 +8,9 @@ ENV PYTHONDONTWRITEBYTECODE=1 \ # OS packages needed: git for history ingestion RUN apt-get update && apt-get install -y --no-install-recommends git ca-certificates && rm -rf /var/lib/apt/lists/* -RUN pip install --no-cache-dir qdrant-client fastembed watchdog onnxruntime tokenizers tree_sitter tree_sitter_languages +# Python deps: reuse shared requirements file +COPY requirements.txt /tmp/requirements.txt +RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt # Bake scripts into the image so we can mount arbitrary code at /work COPY scripts /app/scripts diff --git a/Dockerfile.mcp-indexer b/Dockerfile.mcp-indexer index 2fdb14cf..dcd2aa9f 100644 --- a/Dockerfile.mcp-indexer +++ b/Dockerfile.mcp-indexer @@ -10,9 +10,9 @@ ENV PYTHONDONTWRITEBYTECODE=1 \ RUN apt-get update && apt-get install -y --no-install-recommends git ca-certificates \ && rm -rf /var/lib/apt/lists/* -# Python deps: include FastMCP with Streamable HTTP (RMCP) support -RUN pip install --no-cache-dir --upgrade qdrant-client fastembed watchdog onnxruntime tokenizers \ - tree_sitter tree_sitter_languages mcp fastmcp +# Python deps: reuse shared requirements (includes FastMCP + OpenAI SDK) +COPY requirements.txt /tmp/requirements.txt +RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt # Bake scripts into the image so entrypoints don't rely on /work COPY scripts /app/scripts diff 
#!/bin/bash
#
# Simplified Claude Code UserPromptSubmit hook for ctx.py.
# Reads the hook's JSON payload on stdin, enhances the user message via
# scripts/ctx.py, and writes the (possibly modified) JSON to stdout.
# On any problem the original input is passed through unchanged.

# Read JSON input from stdin
INPUT=$(cat)

# jq is required both to extract the message and to parse ctx_config.json;
# without it we cannot do anything useful, so pass the input through.
if ! command -v jq >/dev/null 2>&1; then
  echo "$INPUT"
  exit 0
fi

# NOTE(review): the field name .user_message is assumed — confirm against the
# UserPromptSubmit hook input schema before relying on this.
USER_MESSAGE=$(echo "$INPUT" | jq -r '.user_message')

# Skip if empty message
if [ -z "$USER_MESSAGE" ] || [ "$USER_MESSAGE" = "null" ]; then
  echo "$INPUT"
  exit 0
fi

# Easy bypass patterns - any of these will skip ctx enhancement.
# POSIX ERE has no \s escape, so whitespace is matched with [[:space:]].
if [[ "$USER_MESSAGE" =~ ^(noctx|raw|bypass|skip|no-enhance): ]] || \
   [[ "$USER_MESSAGE" =~ ^\\ ]] || \
   [[ "$USER_MESSAGE" =~ ^\< ]] || \
   [[ "$USER_MESSAGE" =~ ^(/help|/clear|/exit|/quit) ]] || \
   [[ "$USER_MESSAGE" =~ ^\?[[:space:]]*$ ]] || \
   [ ${#USER_MESSAGE} -lt 12 ]; then
  echo "$INPUT"
  exit 0
fi

# Run relative to this script's directory so ctx_config.json and
# scripts/ctx.py resolve regardless of the caller's working directory.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR" || { echo "$INPUT"; exit 0; }

# Read settings from ctx_config.json with jq (robust JSON parsing; the previous
# grep/sed extraction broke on reformatting or escaped quotes).
# select(. != null) emits the value only when the key is present and non-null,
# and — unlike `// empty` — still passes through a literal boolean false.
CONFIG_FILE="ctx_config.json"
if [ -f "$CONFIG_FILE" ]; then
  CTX_COLLECTION=$(jq -r '.default_collection | select(. != null)' "$CONFIG_FILE")
  REFRAG_RUNTIME=$(jq -r '.refrag_runtime | select(. != null)' "$CONFIG_FILE")
  GLM_API_KEY=$(jq -r '.glm_api_key | select(. != null)' "$CONFIG_FILE")
  GLM_API_BASE=$(jq -r '.glm_api_base | select(. != null)' "$CONFIG_FILE")
  GLM_MODEL=$(jq -r '.glm_model | select(. != null)' "$CONFIG_FILE")
  CTX_DEFAULT_MODE=$(jq -r '.default_mode | select(. != null)' "$CONFIG_FILE")
  CTX_REQUIRE_CONTEXT=$(jq -r '.require_context | select(. != null)' "$CONFIG_FILE")
  CTX_RELEVANCE_GATE=$(jq -r '.relevance_gate_enabled | select(. != null)' "$CONFIG_FILE")
  CTX_MIN_RELEVANCE=$(jq -r '.min_relevance | select(. != null)' "$CONFIG_FILE")
fi

# Defaults for anything missing from the config
CTX_COLLECTION=${CTX_COLLECTION:-"codebase"}
REFRAG_RUNTIME=${REFRAG_RUNTIME:-"glm"}
GLM_API_KEY=${GLM_API_KEY:-}
GLM_API_BASE=${GLM_API_BASE:-}
GLM_MODEL=${GLM_MODEL:-"glm-4.6"}
CTX_DEFAULT_MODE=${CTX_DEFAULT_MODE:-"default"}
CTX_REQUIRE_CONTEXT=${CTX_REQUIRE_CONTEXT:-true}
CTX_RELEVANCE_GATE=${CTX_RELEVANCE_GATE:-false}
CTX_MIN_RELEVANCE=${CTX_MIN_RELEVANCE:-0.1}

# Export GLM/context environment variables from config for ctx.py
export REFRAG_RUNTIME GLM_API_KEY GLM_API_BASE GLM_MODEL CTX_REQUIRE_CONTEXT CTX_RELEVANCE_GATE CTX_MIN_RELEVANCE

# Build ctx command with optional mode flag
CTX_CMD=(python3 scripts/ctx.py)
case "${CTX_DEFAULT_MODE,,}" in
  unicorn)
    CTX_CMD+=("--unicorn")
    ;;
  detail)
    CTX_CMD+=("--detail")
    ;;
esac
CTX_CMD+=("$USER_MESSAGE" --collection "$CTX_COLLECTION")

# Run ctx with a hard timeout. Fall back to the original message when ctx
# fails, times out, or prints nothing — the previous `cmd || echo` form could
# concatenate partial ctx output with the fallback, and an empty-but-successful
# run would have blanked the message entirely.
ENHANCED=$(timeout 30s "${CTX_CMD[@]}" 2>/dev/null)
if [ $? -ne 0 ] || [ -z "$ENHANCED" ]; then
  ENHANCED="$USER_MESSAGE"
fi

# Replace user message with enhanced version using jq
echo "$INPUT" | jq --arg enhanced "$ENHANCED" '.user_message = $enhanced'
index 36115d23..4c3efcb6 100644 --- a/ctx_config.example.json +++ b/ctx_config.example.json @@ -1,7 +1,16 @@ { + "default_collection": "codebase", + "refrag_runtime": "glm", + "glm_api_key": "", + "glm_api_base": "https://api.z.ai/api/coding/paas/v4/", + "glm_model": "glm-4.6", "always_include_tests": true, "prefer_bullet_commands": false, "extra_instructions": "Always consider error handling and edge cases", - "streaming": true + "default_mode": "unicorn", + "streaming": true, + "require_context": true, + "relevance_gate_enabled": false, + "min_relevance": 0.1 } diff --git a/docker-compose.dev-remote.yml b/docker-compose.dev-remote.yml index 27a2a4ca..74f4444a 100644 --- a/docker-compose.dev-remote.yml +++ b/docker-compose.dev-remote.yml @@ -75,6 +75,12 @@ services: - FASTMCP_HOST=${FASTMCP_HOST} - FASTMCP_INDEXER_PORT=${FASTMCP_INDEXER_PORT} - QDRANT_URL=${QDRANT_URL} + - REFRAG_DECODER=${REFRAG_DECODER:-1} + - REFRAG_RUNTIME=${REFRAG_RUNTIME:-llamacpp} + - GLM_API_KEY=${GLM_API_KEY} + - GLM_API_BASE=${GLM_API_BASE:-https://api.z.ai/api/paas/v4/} + - GLM_MODEL=${GLM_MODEL:-glm-4.6} + - LLAMACPP_URL=${LLAMACPP_URL:-http://llamacpp:8080} - COLLECTION_NAME=${COLLECTION_NAME} - PATH_EMIT_MODE=container - HF_HOME=/tmp/huggingface @@ -156,6 +162,12 @@ services: - FASTMCP_INDEXER_PORT=8001 - FASTMCP_TRANSPORT=${FASTMCP_HTTP_TRANSPORT} - QDRANT_URL=${QDRANT_URL} + - REFRAG_DECODER=${REFRAG_DECODER:-1} + - REFRAG_RUNTIME=${REFRAG_RUNTIME:-llamacpp} + - GLM_API_KEY=${GLM_API_KEY} + - GLM_API_BASE=${GLM_API_BASE:-https://api.z.ai/api/paas/v4/} + - GLM_MODEL=${GLM_MODEL:-glm-4.6} + - LLAMACPP_URL=${LLAMACPP_URL:-http://llamacpp:8080} - FASTMCP_HEALTH_PORT=18001 - COLLECTION_NAME=${COLLECTION_NAME} - PATH_EMIT_MODE=container diff --git a/scripts/ctx.py b/scripts/ctx.py index 2e7473d8..92560559 100755 --- a/scripts/ctx.py +++ b/scripts/ctx.py @@ -1,4 +1,5 @@ import re +import difflib #!/usr/bin/env python3 """ @@ -55,6 +56,11 @@ from urllib.error import HTTPError, URLError 
from typing import Dict, Any, List, Optional, Tuple +try: + from scripts.mcp_router import call_tool_http # type: ignore +except ModuleNotFoundError: # pragma: no cover - local execution fallback + from mcp_router import call_tool_http # type: ignore + # Configuration from environment MCP_URL = os.environ.get("MCP_INDEXER_URL", "http://localhost:8003/mcp") DEFAULT_LIMIT = int(os.environ.get("CTX_LIMIT", "5")) @@ -130,6 +136,11 @@ def get_session_id(timeout: int = 10) -> str: session_id = resp.headers.get("mcp-session-id") if not session_id: raise RuntimeError("Server did not return session ID") + # Read the initialization response to ensure session is fully established + init_response = resp.read().decode('utf-8') + # Wait a moment for session to be fully processed + import time + time.sleep(0.5) _session_id = session_id return session_id except Exception as e: @@ -138,8 +149,6 @@ def get_session_id(timeout: int = 10) -> str: def call_mcp_tool(tool_name: str, params: Dict[str, Any], timeout: int = 30) -> Dict[str, Any]: """Call MCP tool via HTTP JSON-RPC with session management.""" - session_id = get_session_id() - payload = { "jsonrpc": "2.0", "id": 1, @@ -147,23 +156,12 @@ def call_mcp_tool(tool_name: str, params: Dict[str, Any], timeout: int = 30) -> "params": {"name": tool_name, "arguments": params} } + # Debug output + sys.stderr.write(f"[DEBUG] Sending payload: {json.dumps(payload, indent=2)}\n") + sys.stderr.flush() + try: - req = request.Request( - MCP_URL, - data=json.dumps(payload).encode(), - headers={ - "Content-Type": "application/json", - "Accept": "application/json, text/event-stream", - "mcp-session-id": session_id - } - ) - with request.urlopen(req, timeout=timeout) as resp: - response_text = resp.read().decode('utf-8') - return parse_sse_response(response_text) - except HTTPError as e: - return {"error": f"HTTP {e.code}: {e.reason}"} - except URLError as e: - return {"error": f"Connection failed: {e.reason}"} + return call_tool_http(MCP_URL, 
tool_name, params, timeout=float(timeout)) except Exception as e: return {"error": f"Request failed: {str(e)}"} @@ -281,7 +279,17 @@ def _ensure_two_paragraph_questions(text: str) -> str: # Collapse triple+ newlines to double while "\n\n\n" in t: t = t.replace("\n\n\n", "\n\n") - paras = [p.strip() for p in t.split("\n\n") if p.strip()] + raw_paras = [p.strip() for p in t.split("\n\n") if p.strip()] + + # Deduplicate paragraphs (case/whitespace insensitive, tolerance for near-duplicates) + paras: list[str] = [] + dedup_keys: list[str] = [] + for p in raw_paras: + key = re.sub(r"\s+", " ", p).strip().lower() + if any(difflib.SequenceMatcher(None, key, existing).ratio() >= 0.99 for existing in dedup_keys): + continue + dedup_keys.append(key) + paras.append(p) def normalize_paragraph(s: str) -> str: """Ensure proper punctuation - keep questions as questions, commands as commands.""" @@ -304,9 +312,10 @@ def normalize_paragraph(s: str) -> str: return s[:-1].rstrip() + "." return s + "." + max_paragraphs = 3 if len(paras) >= 2: - p1, p2 = normalize_paragraph(paras[0]), normalize_paragraph(paras[1]) - return p1 + "\n\n" + p2 + selected = [normalize_paragraph(p) for p in paras[:max_paragraphs]] + return "\n\n".join(selected) # Single paragraph: try to split by sentence boundary p = paras[0] if paras else t @@ -319,7 +328,7 @@ def normalize_paragraph(s: str) -> str: else: p1 = p.strip() p2 = ( - "Additionally, clarify algorithmic steps, inputs/outputs, configuration parameters, performance considerations, error handling behavior, tests, and edge cases relevant to the referenced components" + "Detail the exact systems involved (e.g., files, classes, state machines), how data flows between them, and any validation before emitting updates." 
) return normalize_paragraph(p1) + "\n\n" + normalize_paragraph(p2) @@ -370,6 +379,50 @@ def build_refined_query(original_query: str, allowed_paths: Set[str], allowed_sy return (original_query or "").strip() + (" " + " ".join(terms) if terms else "") +def _simple_tokenize(text: str) -> List[str]: + tokens = re.findall(r"[A-Za-z0-9_]+", text or "") + return [t.lower() for t in tokens if t] + + +def _token_overlap_ratio(a: str, b: str) -> float: + a_tokens = set(_simple_tokenize(a)) + b_tokens = set(_simple_tokenize(b)) + if not a_tokens or not b_tokens: + return 0.0 + inter = len(a_tokens & b_tokens) + union = len(a_tokens | b_tokens) + if not union: + return 0.0 + return inter / union + + +def _estimate_query_result_relevance(query: str, results: List[Dict[str, Any]]) -> float: + q_tokens = set(_simple_tokenize(query)) + if not q_tokens or not results: + return 0.0 + scores: List[float] = [] + for hit in results[:5]: + parts: List[str] = [] + for key in ("path", "symbol", "snippet"): + val = hit.get(key) + if isinstance(val, str): + parts.append(val) + if not parts: + continue + r_tokens = set() + for part in parts: + r_tokens.update(_simple_tokenize(part)) + if not r_tokens: + continue + inter = len(q_tokens & r_tokens) + union = len(q_tokens | r_tokens) + if union: + scores.append(inter / union) + if not scores: + return 0.0 + return sum(scores) / len(scores) + + def sanitize_citations(text: str, allowed_paths: Set[str]) -> str: """Replace path-like strings not present in allowed_paths with a neutral phrase. 
@@ -482,6 +535,19 @@ def enhance_prompt(query: str, **filters) -> str: filters = _adaptive_context_sizing(query, filters) context_text, context_note = fetch_context(query, **filters) + + require_ctx_flag = os.environ.get("CTX_REQUIRE_CONTEXT", "").strip().lower() + if require_ctx_flag in {"1", "true", "yes", "on"}: + has_real_context = bool((context_text or "").strip()) and not ( + context_note and ( + "failed" in context_note.lower() + or "no relevant" in context_note.lower() + or "no data" in context_note.lower() + ) + ) + if not has_real_context: + return (query or "").strip() + rewrite_opts = filters.get("rewrite_options") or {} rewritten = rewrite_prompt( query, @@ -625,6 +691,38 @@ def _needs_polish(text: str) -> bool: return False +def _dedup_paragraphs(text: str, max_paragraphs: int = 3) -> str: + """Deterministic paragraph-level deduplication and truncation. + + - Split on double-newline boundaries + - Drop duplicate paragraphs beyond the first occurrence (case/whitespace insensitive) + - Cap total paragraphs to max_paragraphs + """ + if not text: + return "" + + # Normalize newlines and split into paragraphs + t = text.replace("\r\n", "\n").replace("\r", "\n").strip() + raw_paras = [p.strip() for p in t.split("\n\n") if p.strip()] + if not raw_paras: + return text.strip() + + seen_keys: set[str] = set() + out: list[str] = [] + for p in raw_paras: + key = re.sub(r"\s+", " ", p).strip().lower() + if key in seen_keys: + continue + seen_keys.add(key) + out.append(p) + if len(out) >= max_paragraphs: + break + + if not out: + return text.strip() + return "\n\n".join(out) + + def enhance_unicorn(query: str, **filters) -> str: """Multi-pass staged enhancement for higher quality with optional plan generation. 
@@ -661,6 +759,20 @@ def enhance_unicorn(query: str, **filters) -> str: allowed_paths1, _ = extract_allowed_citations(ctx1) refined_query = draft + overlap = _token_overlap_ratio(query, draft) + sys.stderr.write(f"[DEBUG] Unicorn draft similarity={overlap:.3f}\n") + sys.stderr.flush() + gate_flag = os.environ.get("CTX_DRAFT_SIM_GATE", "").strip().lower() + if gate_flag in {"1", "true", "yes", "on"}: + try: + min_sim = float(os.environ.get("CTX_MIN_DRAFT_SIM", "0.4")) + except Exception: + min_sim = 0.4 + if overlap < min_sim: + sys.stderr.write(f"[DEBUG] Draft similarity below threshold {min_sim:.3f}; reusing original query for pass2.\n") + sys.stderr.flush() + refined_query = query + # ---- Pass 2: refine (even richer snippets, focused results) f2 = dict(filters) f2.update({ @@ -724,13 +836,15 @@ def fetch_context(query: str, **filters) -> Tuple[str, str]: Falls back to context_search (with memories) if repo_search returns no hits. """ with_snippets = bool(filters.get("with_snippets", False)) + # Resolve collection: explicit filter wins, then env COLLECTION_NAME, then default "codebase" + collection_name = filters.get("collection") or os.environ.get("COLLECTION_NAME", "codebase") + params = { "query": query, "limit": filters.get("limit", DEFAULT_LIMIT), "include_snippet": with_snippets, "context_lines": filters.get("context_lines", DEFAULT_CONTEXT_LINES), - "per_path": filters.get("per_path", DEFAULT_PER_PATH), - "collection": "codebase", # Use the correct collection name + "collection": collection_name, } for key in ["language", "under", "path_glob", "not_glob", "kind", "symbol", "ext"]: if filters.get(key): @@ -750,9 +864,21 @@ def fetch_context(query: str, **filters) -> Tuple[str, str]: return "", "Context retrieval returned no data." 
hits = data.get("results") or [] - sys.stderr.write(f"[DEBUG] repo_search returned {len(hits)} hits\n") + relevance = _estimate_query_result_relevance(query, hits) + sys.stderr.write(f"[DEBUG] repo_search returned {len(hits)} hits (relevance={relevance:.3f})\n") sys.stderr.flush() + gate_flag = os.environ.get("CTX_RELEVANCE_GATE", "").strip().lower() + if hits and gate_flag in {"1", "true", "yes", "on"}: + try: + min_rel = float(os.environ.get("CTX_MIN_RELEVANCE", "0.15")) + except Exception: + min_rel = 0.15 + if relevance < min_rel: + sys.stderr.write(f"[DEBUG] Relevance below threshold {min_rel:.3f}; treating as no relevant context.\n") + sys.stderr.flush() + return "", "No relevant context found for the prompt (low retrieval relevance)." + if not hits: # Memory blending: try context_search with memories as fallback memory_params = { @@ -761,7 +887,7 @@ def fetch_context(query: str, **filters) -> Tuple[str, str]: "include_memories": True, "include_snippet": with_snippets, "context_lines": filters.get("context_lines", DEFAULT_CONTEXT_LINES), - "collection": "codebase", # Use the correct collection name + "collection": collection_name, } memory_result = call_mcp_tool("context_search", memory_params) if "error" not in memory_result: @@ -776,7 +902,7 @@ def fetch_context(query: str, **filters) -> Tuple[str, str]: def rewrite_prompt(original_prompt: str, context: str, note: str, max_tokens: Optional[int], citation_policy: str = "paths", stream: bool = True) -> str: - """Use the local decoder (llama.cpp) to rewrite the prompt with repository context. + """Use the configured decoder (GLM or llama.cpp) to rewrite the prompt with repository context. Returns ONLY the improved prompt text. Raises exception if decoder fails. If stream=True (default), prints tokens as they arrive for instant feedback. 
@@ -872,78 +998,136 @@ def rewrite_prompt(original_prompt: str, context: str, note: str, max_tokens: Op if prefs.get("streaming") is not None: stream = prefs.get("streaming") - meta_prompt = ( - "<|start_of_role|>system<|end_of_role|>" + system_msg + "<|end_of_text|>\n" - "<|start_of_role|>user<|end_of_role|>" + user_msg + "<|end_of_text|>\n" - "<|start_of_role|>assistant<|end_of_role|>" - ) + # Check which decoder runtime to use + runtime_kind = str(os.environ.get("REFRAG_RUNTIME", "llamacpp")).strip().lower() - decoder_url = DECODER_URL - # Safety: only allow local decoder hosts - parsed = urlparse(decoder_url) - if parsed.hostname not in {"localhost", "127.0.0.1", "host.docker.internal"}: - raise ValueError(f"Unsafe decoder host: {parsed.hostname}") - payload = { - "prompt": meta_prompt, - "n_predict": int(max_tokens or DEFAULT_REWRITE_TOKENS), - "temperature": 0.45, - "stream": stream, - } + if runtime_kind == "glm": + from refrag_glm import GLMRefragClient # type: ignore + client = GLMRefragClient() - req = request.Request( - decoder_url, - data=json.dumps(payload).encode("utf-8"), - headers={"Content-Type": "application/json"}, - ) + # GLM uses OpenAI-style chat completions, convert context to user prompt format + # Note: For GLM, we need to convert the meta_prompt format to simple user message + user_msg = ( + f"Context refs:\n{effective_context}\n\n" + f"Original prompt: {(original_prompt or '').strip()}\n\n" + "Rewrite this as a more specific, detailed prompt using at least two short paragraphs separated by a blank line. " + ) + + if has_code_context: + user_msg += ( + "Use the context above to make the rewrite concrete and specific. " + "For questions: make them more specific and multi-faceted (each paragraph should be a question ending with '?'). " + "For commands/instructions: make them more detailed and concrete (specify exact functions, parameters, edge cases to handle). 
" + ) + else: + user_msg += ( + "Since no code context is available, keep the rewrite general and exploratory. " + "Do NOT invent specific file paths, line numbers, or function names. " + "For questions: expand into related conceptual questions. For commands/instructions: provide general guidance about the task. " + ) + + # GLM API call + response = client.client.chat.completions.create( + model=os.environ.get("GLM_MODEL", "glm-4.6"), + messages=[ + {"role": "system", "content": system_msg}, + {"role": "user", "content": user_msg} + ], + max_tokens=int(max_tokens or DEFAULT_REWRITE_TOKENS), + temperature=0.45, + stream=stream + ) + + enhanced = "" + if stream: + # Streaming mode for GLM + for chunk in response: + if chunk.choices[0].delta.content: + token = chunk.choices[0].delta.content + sys.stdout.write(token) + sys.stdout.flush() + enhanced += token + sys.stdout.write("\n") + sys.stdout.flush() + else: + # Non-streaming mode for GLM + enhanced = response.choices[0].message.content - enhanced = "" - if stream: - # Streaming mode: print tokens as they arrive for instant feedback - with request.urlopen(req, timeout=DECODER_TIMEOUT) as resp: - for line in resp: - line_str = line.decode("utf-8", errors="ignore").strip() - if not line_str or line_str.startswith(":"): - continue - if line_str.startswith("data: "): - line_str = line_str[6:] - try: - chunk = json.loads(line_str) - token = chunk.get("content", "") - if token: - sys.stdout.write(token) - sys.stdout.flush() - enhanced += token - if chunk.get("stop", False): - break - except json.JSONDecodeError as e: - # Warn once per malformed line but keep streaming the final output only - sys.stderr.write(f"[WARN] decoder stream JSON decode failed: {str(e)}\n") - sys.stderr.flush() - continue - sys.stdout.write("\n") - sys.stdout.flush() else: - # Non-streaming mode: wait for full response - with request.urlopen(req, timeout=DECODER_TIMEOUT) as resp: - raw = resp.read().decode("utf-8", errors="ignore") - data = 
json.loads(raw) + # Use llama.cpp decoder (original logic) + meta_prompt = ( + "<|start_of_role|>system<|end_of_role|>" + system_msg + "<|end_of_text|>\n" + "<|start_of_role|>user<|end_of_role|>" + user_msg + "<|end_of_text|>\n" + "<|start_of_role|>assistant<|end_of_role|>" + ) - # Extract content from llama.cpp response - enhanced = ( - (data.get("content") if isinstance(data, dict) else None) - or ((data.get("choices") or [{}])[0].get("content") if isinstance(data, dict) else None) - or ((data.get("choices") or [{}])[0].get("text") if isinstance(data, dict) else None) - or (data.get("generated_text") if isinstance(data, dict) else None) - or (data.get("text") if isinstance(data, dict) else None) - ) + decoder_url = DECODER_URL + # Safety: only allow local decoder hosts + parsed = urlparse(decoder_url) + if parsed.hostname not in {"localhost", "127.0.0.1", "host.docker.internal"}: + raise ValueError(f"Unsafe decoder host: {parsed.hostname}") + payload = { + "prompt": meta_prompt, + "n_predict": int(max_tokens or DEFAULT_REWRITE_TOKENS), + "temperature": 0.45, + "stream": stream, + } + + req = request.Request( + decoder_url, + data=json.dumps(payload).encode("utf-8"), + headers={"Content-Type": "application/json"}, + ) + + enhanced = "" + if stream: + # Streaming mode: print tokens as they arrive for instant feedback + with request.urlopen(req, timeout=DECODER_TIMEOUT) as resp: + for line in resp: + line_str = line.decode("utf-8", errors="ignore").strip() + if not line_str or line_str.startswith(":"): + continue + if line_str.startswith("data: "): + line_str = line_str[6:] + try: + chunk = json.loads(line_str) + token = chunk.get("content", "") + if token: + sys.stdout.write(token) + sys.stdout.flush() + enhanced += token + if chunk.get("stop", False): + break + except json.JSONDecodeError as e: + # Warn once per malformed line but keep streaming the final output only + sys.stderr.write(f"[WARN] decoder stream JSON decode failed: {str(e)}\n") + sys.stderr.flush() + 
continue + sys.stdout.write("\n") + sys.stdout.flush() + else: + # Non-streaming mode: wait for full response + with request.urlopen(req, timeout=DECODER_TIMEOUT) as resp: + raw = resp.read().decode("utf-8", errors="ignore") + data = json.loads(raw) + + # Extract content from llama.cpp response + enhanced = ( + (data.get("content") if isinstance(data, dict) else None) + or ((data.get("choices") or [{}])[0].get("content") if isinstance(data, dict) else None) + or ((data.get("choices") or [{}])[0].get("text") if isinstance(data, dict) else None) + or (data.get("generated_text") if isinstance(data, dict) else None) + or (data.get("text") if isinstance(data, dict) else None) + ) enhanced = (enhanced or "").replace("```", "").replace("`", "").strip() if not enhanced: raise ValueError("Decoder returned empty response") - # Enforce at least two question paragraphs + # Enforce at least two question paragraphs, then deduplicate and cap paragraphs enhanced = _ensure_two_paragraph_questions(enhanced) + enhanced = _dedup_paragraphs(enhanced, max_paragraphs=3) return enhanced @@ -1011,6 +1195,7 @@ def main(): parser.add_argument("--kind", help="Filter by symbol kind (e.g., function, class)") parser.add_argument("--symbol", help="Filter by symbol name") parser.add_argument("--ext", help="Filter by file extension") + parser.add_argument("--collection", help="Override collection name (default: env COLLECTION_NAME)") # Output control parser.add_argument("--limit", type=int, default=DEFAULT_LIMIT, @@ -1039,6 +1224,7 @@ def main(): "kind": args.kind, "symbol": args.symbol, "ext": args.ext, + "collection": args.collection, "per_path": args.per_path, "with_snippets": args.detail, "rewrite_options": { @@ -1066,8 +1252,24 @@ def main(): output = enhance_unicorn(args.query, **filters) else: context_text, context_note = fetch_context(args.query, **filters) - rewritten = rewrite_prompt(args.query, context_text, context_note, max_tokens=args.rewrite_max_tokens) - output = rewritten.strip() + 
+ require_ctx_flag = os.environ.get("CTX_REQUIRE_CONTEXT", "").strip().lower() + if require_ctx_flag in {"1", "true", "yes", "on"}: + has_real_context = bool((context_text or "").strip()) and not ( + context_note and ( + "failed" in context_note.lower() + or "no relevant" in context_note.lower() + or "no data" in context_note.lower() + ) + ) + if not has_real_context: + output = (args.query or "").strip() + else: + rewritten = rewrite_prompt(args.query, context_text, context_note, max_tokens=args.rewrite_max_tokens) + output = rewritten.strip() + else: + rewritten = rewrite_prompt(args.query, context_text, context_note, max_tokens=args.rewrite_max_tokens) + output = rewritten.strip() if args.cmd: subprocess.run(args.cmd, input=output.encode("utf-8"), shell=True, check=False) diff --git a/scripts/mcp_indexer_server.py b/scripts/mcp_indexer_server.py index 95b047be..4690b584 100644 --- a/scripts/mcp_indexer_server.py +++ b/scripts/mcp_indexer_server.py @@ -5965,11 +5965,24 @@ def _ca_build_prompt( def _ca_decode( - prompt: str, *, mtok: int, temp: float, top_k: int, top_p: float, stops: list[str] + prompt: str, + *, + mtok: int, + temp: float, + top_k: int, + top_p: float, + stops: list[str], + timeout: float | None = None, ) -> str: - from scripts.refrag_llamacpp import LlamaCppRefragClient # type: ignore + runtime_kind = str(os.environ.get("REFRAG_RUNTIME", "llamacpp")).strip().lower() + if runtime_kind == "glm": + from scripts.refrag_glm import GLMRefragClient # type: ignore + + client = GLMRefragClient() + else: + from scripts.refrag_llamacpp import LlamaCppRefragClient # type: ignore - client = LlamaCppRefragClient() + client = LlamaCppRefragClient() base_tokens = int(max(16, mtok)) last_err: Optional[Exception] = None import time as _time @@ -5979,16 +5992,41 @@ def _ca_decode( base_tokens if attempt == 0 else max(16, base_tokens // (2 if attempt == 1 else 3)) ) try: - return client.generate_with_soft_embeddings( - prompt=prompt, - max_tokens=cur_tokens, - 
temperature=temp, - top_k=top_k, - top_p=top_p, - stop=stops, - repeat_penalty=float(os.environ.get("DECODER_REPEAT_PENALTY", "1.15") or 1.15), - repeat_last_n=int(os.environ.get("DECODER_REPEAT_LAST_N", "128") or 128), - ) + gen_kwargs = { + "max_tokens": cur_tokens, + "temperature": temp, + "top_p": top_p, + "stop": stops, + } + if runtime_kind == "glm": + timeout_value: Optional[float] = None + if timeout is not None: + try: + timeout_value = float(timeout) + except Exception: + timeout_value = None + if timeout_value is None: + raw_timeout = os.environ.get("GLM_TIMEOUT_SEC", "").strip() + if raw_timeout: + try: + timeout_value = float(raw_timeout) + except Exception: + timeout_value = None + if timeout_value is not None: + gen_kwargs["timeout"] = timeout_value + else: + gen_kwargs.update( + { + "top_k": top_k, + "repeat_penalty": float( + os.environ.get("DECODER_REPEAT_PENALTY", "1.15") or 1.15 + ), + "repeat_last_n": int( + os.environ.get("DECODER_REPEAT_LAST_N", "128") or 128 + ), + } + ) + return client.generate_with_soft_embeddings(prompt=prompt, **gen_kwargs) except Exception as e: last_err = e # Allow quick retries with reduced budget and tiny backoff to rescue transient 5xx @@ -6908,7 +6946,13 @@ def _k(s: Dict[str, Any]): ) with _env_overrides({"LLAMACPP_TIMEOUT_SEC": str(_llama_timeout)}): answer = _ca_decode( - prompt, mtok=mtok, temp=temp, top_k=top_k, top_p=top_p, stops=stops + prompt, + mtok=mtok, + temp=temp, + top_k=top_k, + top_p=top_p, + stops=stops, + timeout=_llama_timeout, ) # Post-process and validate @@ -7010,6 +7054,7 @@ def _tok2(s: str) -> list[str]: top_k=top_k, top_p=top_p, stops=stops, + timeout=_llama_timeout, ) # Minimal post-processing with per-query identifier inference diff --git a/scripts/refrag_glm.py b/scripts/refrag_glm.py index 3e60a274..7905c4b2 100644 --- a/scripts/refrag_glm.py +++ b/scripts/refrag_glm.py @@ -38,6 +38,11 @@ def generate_with_soft_embeddings( temperature = float(gen_kwargs.get("temperature", 0.2)) top_p = 
float(gen_kwargs.get("top_p", 0.95)) stop = gen_kwargs.get("stop") + timeout = gen_kwargs.pop("timeout", None) + try: + timeout_val = float(timeout) if timeout is not None else None + except Exception: + timeout_val = None try: response = self.client.chat.completions.create( @@ -47,6 +52,7 @@ def generate_with_soft_embeddings( temperature=temperature, top_p=top_p, stop=stop if stop else None, + timeout=timeout_val, ) msg = response.choices[0].message # GLM-4.6 uses reasoning_content for thinking models