diff --git a/.augment/rules/rules.md b/.augment/rules/rules.md new file mode 100644 index 00000000..488c4aae --- /dev/null +++ b/.augment/rules/rules.md @@ -0,0 +1,163 @@ +--- +type: "manual" +--- + +# Augment Code SPARC Methodology Guidelines + +*This file provides guidelines for the Augment Code AI assistant to follow when helping with development tasks. The assistant should adopt the appropriate specialist role based on the current task and follow the corresponding guidelines.* + +## How to Use These Guidelines + +1. **Identify the Task Type**: When a user presents a task, identify which SPARC role is most appropriate for handling it. + +2. **Adopt the Role**: Explicitly state which role you're adopting (e.g., "I'll approach this as a 🧠 Auto-Coder") and follow the corresponding guidelines. + +3. **Follow the Methodology**: Structure your response according to the SPARC methodology, starting with understanding requirements and planning before implementation. + +4. **Use Augment Tools**: Leverage the appropriate Augment Code tools as specified in each role's guidelines: + - `codebase-retrieval` for understanding existing code + - `str-replace-editor` for making code changes + - `diagnostics` for identifying issues + - `launch-process` for running tests and commands + +5. **Maintain Best Practices**: Ensure all work adheres to the core principles: + + - No hard-coded environment variables + - Modular, testable outputs + +# SPARC Methodology + +## ⚑️ SPARC Orchestrator +- Break down large objectives into logical subtasks following the SPARC methodology: + 1. Specification: Clarify objectives and scope. Never allow hard-coded env vars. + 2. Pseudocode: Create high-level logic with TDD anchors. + 3. Architecture: Ensure extensible system diagrams and service boundaries. + 4. Refinement: Use TDD, debugging, security, and optimization flows. + 5. Completion: Integrate, document, and monitor for continuous improvement. +- Always use codebase-retrieval to understand existing code before planning changes +- Use str-replace-editor for all code modifications +- Validate that files contain no hard-coded env vars, and produce modular, testable outputs + +## πŸ“‹ Specification Writer +- Capture full project contextβ€”functional requirements, edge cases, constraints +- Translate requirements into modular pseudocode with TDD anchors +- Split complex logic across modules +- Never include hard-coded secrets or config values +\- Use codebase-retrieval to understand existing patterns before creating specifications + +## πŸ—οΈ Architect +- Design scalable, secure, and modular architectures based on functional specs and user needs +- Define responsibilities across services, APIs, and components +- Create architecture diagrams, data flows, and integration points +- Ensure no part of the design includes secrets or hardcoded env values +- Emphasize modular boundaries and maintain extensibility +- Use codebase-retrieval to understand existing architecture patterns + +## 🧠 Auto-Coder +- Write clean, efficient, modular code based on pseudocode and architecture +- Use configuration for environments and break large components into maintainable files +- Never hardcode secrets or environment values +\]- Use config files or environment abstractions +- Always use codebase-retrieval to understand existing code patterns before making changes +- Use str-replace-editor for all code modifications + +## πŸ§ͺ Tester (TDD) +- Implement Test-Driven Development (TDD) +- Write failing tests first, then implement only enough code to pass +- Refactor after tests pass +- Ensure tests do not hardcode secrets +\- Validate modularity, test coverage, and clarity +- Use codebase-retrieval to understand existing test patterns +- Use str-replace-editor for all test code modifications +- Use launch-process to run tests and verify results + +## πŸͺ² Debugger +- Troubleshoot runtime bugs, logic errors, or integration failures +- Use logs, traces, and stack analysis to isolate bugs +- Avoid changing env configuration directly +- Keep fixes modular +\- Use codebase-retrieval to understand the code with issues +- Use diagnostics to identify compiler errors and warnings +- Use str-replace-editor to implement fixes +- Use launch-process to run tests and verify fixes + +## πŸ›‘οΈ Security Reviewer +- Perform static and dynamic audits to ensure secure code practices +- Scan for exposed secrets, env leaks, and monoliths +- Recommend mitigations or refactors to reduce risk +- Use codebase-retrieval to scan for security issues +- Use str-replace-editor to implement security fixes + +## πŸ“š Documentation Writer +- Write concise, clear, and modular Markdown documentation +- Explain usage, integration, setup, and configuration +- Use sections, examples, and headings + +- Do not leak env values +- Use codebase-retrieval to understand the code being documented +- Use str-replace-editor to modify documentation files + +## πŸ”— System Integrator +- Merge outputs into a working, tested, production-ready system +- Ensure consistency, cohesion, and modularity +- Verify interface compatibility, shared modules, and env config standards +- Split integration logic across domains as needed +- Use codebase-retrieval to understand the components being integrated +- Use str-replace-editor to implement integration changes +- Use launch-process to run tests and verify integration + +## πŸ“ˆ Deployment Monitor +- Observe the system post-launch +- Collect performance metrics, logs, and user feedback +- Flag regressions or unexpected behaviors +- Configure metrics, logs, uptime checks, and alerts +- Recommend improvements if thresholds are violated +- Use codebase-retrieval to understand monitoring configurations +- Use str-replace-editor to implement monitoring changes +- Use launch-process to verify monitoring configurations + +## 🧹 Optimizer +- Refactor, modularize, and improve system performance +- Enforce file size limits, dependency decoupling, and configuration hygiene +- Audit files for clarity, modularity, and size +- Move inline configs to env files +- Use codebase-retrieval to understand the code being optimized +- Use str-replace-editor to implement optimization changes +- Use launch-process to run tests and verify optimizations + +## πŸš€ DevOps +- Handle deployment, automation, and infrastructure operations +- Provision infrastructure (cloud functions, containers, edge runtimes) +- Deploy services using CI/CD tools or shell commands +- Configure environment variables using secret managers or config layers +- Set up domains, routing, TLS, and monitoring integrations +- Clean up legacy or orphaned resources +- Enforce infrastructure best practices: + - Immutable deployments + - Rollbacks and blue-green strategies + - Never hard-code credentials or tokens + - Use managed secrets +- Use codebase-retrieval to understand existing infrastructure code +- Use str-replace-editor to implement infrastructure changes +- Use launch-process to run deployment commands + +## ❓ Ask +- Guide users to ask questions using SPARC methodology +- Help identify which specialist mode is most appropriate for a given task +- Translate vague problems into targeted prompts +- Ensure requests follow best practices: + - Modular structure + - Environment variable safety +\ +- Use codebase-retrieval to understand the context of questions + +## πŸ“˜ Tutorial +- Guide users through the full SPARC development process +- Explain how to modularize work and delegate tasks +- Teach structured thinking models for different aspects of development +- Ensure users follow best practices: + - No hard-coded environment variables +\ + - Clear handoffs between different specialist roles +- Provide actionable examples and mental models for each SPARC methodology role +- NEVER MONKEY PATCH THINGS. GIVE REAL VALUABLE CODE CONTRIBUTIONS \ No newline at end of file diff --git a/.env b/.env index 91252853..9a644ace 100644 --- a/.env +++ b/.env @@ -114,12 +114,17 @@ REFRAG_SENSE=heuristic GLM_API_KEY= # Llama.cpp sidecar (optional) # Use docker network hostname from containers; localhost remains ok for host-side runs if LLAMACPP_URL not exported -LLAMACPP_URL=http://llamacpp:8080 +LLAMACPP_URL=http://host.docker.internal:8081 LLAMACPP_TIMEOUT_SEC=300 DECODER_MAX_TOKENS=4000 REFRAG_DECODER_MODE=prompt # prompt|soft REFRAG_SOFT_SCALE=1.0 +LLAMACPP_USE_GPU=1 +LLAMACPP_GPU_LAYERS=32 +LLAMACPP_THREADS=6 +LLAMACPP_GPU_SPLIT= +LLAMACPP_EXTRA_ARGS= # Operational safeguards and timeouts @@ -153,3 +158,4 @@ HYBRID_RESULTS_CACHE=128 HYBRID_RESULTS_CACHE_ENABLED=1 INDEX_CHUNK_LINES=60 INDEX_CHUNK_OVERLAP=10 +USE_GPU_DECODER=1 diff --git a/.env.example b/.env.example index 3d179b8e..f16118f2 100644 --- a/.env.example +++ b/.env.example @@ -108,17 +108,26 @@ REFRAG_ENCODER_MODEL=BAAI/bge-base-en-v1.5 REFRAG_PHI_PATH=/work/models/refrag_phi_768_to_dmodel.json REFRAG_SENSE=heuristic -# Llama.cpp sidecar (optional; REFRAG_RUNTIME=llamacpp) +# Llama.cpp sidecar (optional) +# Docker CPU-only (stable): http://llamacpp:8080 +# Native GPU-accelerated (fast): http://localhost:8081 LLAMACPP_URL=http://llamacpp:8080 REFRAG_DECODER_MODE=prompt # prompt|soft -REFRAG_SOFT_SCALE=1.0 +# GPU Performance Toggle +# Set to 1 to use native GPU-accelerated server on localhost:8081 +# Set to 0 to use Docker CPU-only server (default, stable) +USE_GPU_DECODER=0 -# GLM API provider (alternative to llamacpp; REFRAG_RUNTIME=glm) -GLM_API_KEY= -GLM_API_BASE=https://api.z.ai/api/paas/v4/ -GLM_MODEL=glm-4.6 +REFRAG_SOFT_SCALE=1.0 +# Llama.cpp runtime tuning +LLAMACPP_USE_GPU=0 # Set to 1 to enable Metal/CLBlast acceleration +# LLAMACPP_GPU_LAYERS=-1 # Override number of layers to offload (defaults to -1 when USE_GPU=1) +# LLAMACPP_GPU_SPLIT= # Optional tensor split for multi-GPU setups +# LLAMACPP_THREADS= # Override number of CPU threads +# LLAMACPP_CTX_SIZE=8192 # Context tokens; higher values need more VRAM +# LLAMACPP_EXTRA_ARGS= # Additional flags passed verbatim to llama.cpp # Operational safeguards and timeouts # Limit explosion of micro-chunks on huge files (0 to disable) diff --git a/Dockerfile.llamacpp b/Dockerfile.llamacpp index 5162502b..dabd291a 100644 --- a/Dockerfile.llamacpp +++ b/Dockerfile.llamacpp @@ -19,4 +19,3 @@ RUN mkdir -p /models \ && if [ -n "$MODEL_URL" ]; then echo "Fetching model: $MODEL_URL" && curl -L --fail --retry 3 -C - "$MODEL_URL" -o /models/model.gguf; else echo "No MODEL_URL provided; expecting host volume /models"; fi EXPOSE 8080 ENTRYPOINT ["/app/server", "--model", "/models/model.gguf", "--host", "0.0.0.0", "--port", "8080", "--no-warmup"] - diff --git a/Makefile b/Makefile index d9efed6e..1361b2f0 100644 --- a/Makefile +++ b/Makefile @@ -194,7 +194,6 @@ reset-dev-dual: ## bring up BOTH legacy SSE and Streamable HTTP MCPs (dual-compa docker compose run --rm -e INDEX_MICRO_CHUNKS -e MAX_MICRO_CHUNKS_PER_FILE -e TOKENIZER_PATH -e TOKENIZER_URL indexer --root /work --recreate $(MAKE) llama-model docker compose up -d mcp mcp_indexer mcp_http mcp_indexer_http watcher llamacpp - # Ensure watcher is up even if a prior step or manual bring-up omitted it docker compose up -d watcher docker compose ps @@ -272,4 +271,3 @@ qdrant-prune: qdrant-index-root: python3 scripts/mcp_router.py --run "reindex repo" - diff --git a/README.md b/README.md index 0f0130c6..dd468441 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,61 @@ INDEX_MICRO_CHUNKS=1 MAX_MICRO_CHUNKS_PER_FILE=200 make reset-dev-dual - Ports: 8000/8001 (/sse) and 8002/8003 (/mcp) - Command: `INDEX_MICRO_CHUNKS=1 MAX_MICRO_CHUNKS_PER_FILE=200 make reset-dev-dual` -- You can skip the decoder; it’s feature-flagged off by default. +### Environment Configuration + +**Default Setup:** +- The repository includes `.env.example` with sensible defaults for local development +- On first run, copy it to `.env`: `cp .env.example .env` +- The `make reset-dev*` targets will use your `.env` settings automatically + +**Key Configuration Files:** +- `.env` β€” Your local environment variables (gitignored, safe to customize) +- `.env.example` β€” Template with documented defaults (committed to repo) +- `docker-compose.yml` β€” Service definitions that read from `.env` + +**Recommended Customizations:** + +1. **Enable micro-chunking** (better retrieval quality): + ```bash + INDEX_MICRO_CHUNKS=1 + MAX_MICRO_CHUNKS_PER_FILE=200 + ``` + +2. **Enable decoder for Q&A** (context_answer tool): + ```bash + REFRAG_DECODER=1 # Enable decoder (default: 1) + REFRAG_RUNTIME=llamacpp # Use llama.cpp (default) or glm + ``` + +3. **GPU acceleration** (Apple Silicon Metal): + ```bash + # Option A: Use the toggle script (recommended) + scripts/gpu_toggle.sh gpu + scripts/gpu_toggle.sh start + + # Option B: Manual .env settings + USE_GPU_DECODER=1 + LLAMACPP_URL=http://host.docker.internal:8081 + LLAMACPP_GPU_LAYERS=32 # or -1 for all layers + ``` + +4. **Alternative: GLM API** (instead of local llama.cpp): + ```bash + REFRAG_RUNTIME=glm + GLM_API_KEY=your-api-key-here + GLM_MODEL=glm-4.6 # Optional, defaults to glm-4.6 + ``` + +5. **Custom collection name**: + ```bash + COLLECTION_NAME=my-project # Defaults to auto-detected repo name + ``` + +**After changing `.env`:** +- Restart services: `docker compose restart mcp_indexer mcp_indexer_http` +- For indexing changes: `make reindex` or `make reindex-hard` +- For decoder changes: `docker compose up -d --force-recreate llamacpp` (or restart native server) + ### Switch decoder model (llama.cpp) - Default tiny model: Granite 4.0 Micro (Q4_K_M GGUF) - Change the model by overriding Make vars (downloads to ./models/model.gguf): @@ -59,9 +113,19 @@ INDEX_MICRO_CHUNKS=1 MAX_MICRO_CHUNKS_PER_FILE=200 make reset-dev-dual LLAMACPP_MODEL_URL="https://huggingface.co/ORG/MODEL/resolve/main/model.gguf" \ INDEX_MICRO_CHUNKS=1 MAX_MICRO_CHUNKS_PER_FILE=200 make reset-dev-dual ``` +- Want GPU acceleration? Set `LLAMACPP_USE_GPU=1` (optionally `LLAMACPP_GPU_LAYERS=-1`) in your `.env` before `docker compose up`, or simply run `scripts/gpu_toggle.sh gpu` (described below) to flip the switch for you. - Embeddings: set EMBEDDING_MODEL in .env and reindex (make reindex) +Decoder env toggles (set in `.env` and managed automatically by `scripts/gpu_toggle.sh`): + +| Variable | Description | Typical values | +|-----------------------|-------------------------------------------------------|----------------| +| `USE_GPU_DECODER` | Feature-flag for native Metal decoder | `0` (docker), `1` (native) | +| `LLAMACPP_URL` | Decoder endpoint containers should use | `http://llamacpp:8080` or `http://host.docker.internal:8081` | +| `LLAMACPP_GPU_LAYERS` | Number of layers to offload to GPU (`-1` = all) | `0`, `32`, `-1` | + + Alternative (compose only) ```bash HOST_INDEX_PATH="$(pwd)" FASTMCP_INDEXER_PORT=8001 docker compose up -d qdrant mcp mcp_indexer indexer watcher @@ -73,6 +137,28 @@ HOST_INDEX_PATH="$(pwd)" FASTMCP_INDEXER_PORT=8001 docker compose up -d qdrant m 3. Confirm collection health with `make qdrant-status` (calls the MCP router to print counts and timestamps). 4. Iterate using search helpers such as `make hybrid ARGS="--query 'async file watcher'"` or invoke the MCP tools directly from your client. +### Apple Silicon Metal GPU (native) vs Docker decoder + +On Apple Silicon you can run the llama.cpp decoder natively with Metal while keeping the rest of the stack in Docker: + +1. Install the Metal-enabled llama.cpp binary (e.g. `brew install llama.cpp`). +2. Flip to GPU mode and start the native server: + ```bash + scripts/gpu_toggle.sh gpu + scripts/gpu_toggle.sh start # launches llama-server on localhost:8081 + docker compose up -d --force-recreate mcp_indexer mcp_indexer_http + docker compose stop llamacpp # optional once the native server is healthy + ``` + The toggle updates `.env` to point at `http://host.docker.internal:8081` so containers reach the host process. +3. Run `scripts/gpu_toggle.sh status` to confirm the native server is healthy. All MCP `context_answer` calls will now use the Metal-backed decoder. + +Want the original dockerised decoder (CPU-only or x86 GPU fallback)? Swap back with: +```bash +scripts/gpu_toggle.sh docker +docker compose up -d --force-recreate mcp_indexer mcp_indexer_http llamacpp +``` +This re-enables the `llamacpp` container and resets `.env` to `http://llamacpp:8080`. + ### Make targets (quick reference) - reset-dev: SSE stack on 8000/8001; seeds Qdrant, downloads tokenizer + tiny llama.cpp model, reindexes, brings up memory + indexer + watcher - reset-dev-codex: RMCP stack on 8002/8003; same seeding + bring-up for Codex/Qodo @@ -166,6 +252,25 @@ Reranker - RERANKER_ENABLED: 1/true to enable, 0/false to disable; default is enabled in server - Timeouts/failures automatically fall back to hybrid results +Decoder (llama.cpp / GLM) +- REFRAG_DECODER: 1 to enable decoder for context_answer; 0 to disable (default: 1) +- REFRAG_RUNTIME: llamacpp or glm (default: llamacpp) +- LLAMACPP_URL: llama.cpp server endpoint (default: http://llamacpp:8080 or http://host.docker.internal:8081 for GPU) +- LLAMACPP_TIMEOUT_SEC: Decoder request timeout in seconds (default: 300) +- DECODER_MAX_TOKENS: Max tokens for decoder responses (default: 4000) +- REFRAG_DECODER_MODE: prompt or soft (default: prompt; soft requires patched llama.cpp) +- GLM_API_KEY: API key for GLM provider (required when REFRAG_RUNTIME=glm) +- GLM_MODEL: GLM model name (default: glm-4.6) +- USE_GPU_DECODER: 1 for native Metal decoder on host, 0 for Docker (managed by gpu_toggle.sh) +- LLAMACPP_GPU_LAYERS: Number of layers to offload to GPU, -1 for all (default: 32) + +ReFRAG (micro-chunking and retrieval) +- REFRAG_MODE: 1 to enable micro-chunking and span budgeting (default: 1) +- REFRAG_GATE_FIRST: 1 to enable mini-vector gating before dense search (default: 1) +- REFRAG_CANDIDATES: Number of candidates for gate-first filtering (default: 200) +- MICRO_BUDGET_TOKENS: Global token budget for context_answer spans (default: 512) +- MICRO_OUT_MAX_SPANS: Max number of spans to return per query (default: 3) + Ports - FASTMCP_PORT (SSE/RMCP): Override Memory MCP ports (defaults: 8000/8002) - FASTMCP_INDEXER_PORT (SSE/RMCP): Override Indexer MCP ports (defaults: 8001/8003) @@ -201,6 +306,18 @@ Ports | FASTMCP_INDEXER_HTTP_PORT | Indexer RMCP host port mapping | 8003 | | FASTMCP_HEALTH_PORT | Health port (memory/indexer) | memory: 18000; indexer: 18001 | | LLM_EXPAND_MAX | Max alternate queries generated via LLM | 0 | +| REFRAG_DECODER | Enable decoder for context_answer | 1 (enabled) | +| REFRAG_RUNTIME | Decoder backend: llamacpp or glm | llamacpp | +| LLAMACPP_URL | llama.cpp server endpoint | http://llamacpp:8080 or http://host.docker.internal:8081 | +| LLAMACPP_TIMEOUT_SEC | Decoder request timeout | 300 | +| DECODER_MAX_TOKENS | Max tokens for decoder responses | 4000 | +| GLM_API_KEY | API key for GLM provider | unset | +| GLM_MODEL | GLM model name | glm-4.6 | +| USE_GPU_DECODER | Native Metal decoder (1) vs Docker (0) | 0 (docker) | +| REFRAG_MODE | Enable micro-chunking and span budgeting | 1 (enabled) | +| REFRAG_GATE_FIRST | Enable mini-vector gating | 1 (enabled) | +| REFRAG_CANDIDATES | Candidates for gate-first filtering | 200 | +| MICRO_BUDGET_TOKENS | Token budget for context_answer | 512 | ## Running tests @@ -317,7 +434,7 @@ Notes: - SSE β€œInvalid session ID” when POSTing /messages directly: - Expected if you didn’t initiate an SSE session first. Use an MCP client (e.g., mcp-remote) to handle the handshake. - llama.cpp platform warning on Apple Silicon: - - Safe to ignore for local dev, or set platform: linux/amd64 for the service, or build a native image. + - Prefer the native path above (`scripts/gpu_toggle.sh gpu`). If you stick with Docker, add `platform: linux/amd64` to the service or ignore the warning during local dev. - Indexing feels stuck on very large files: - Use MAX_MICRO_CHUNKS_PER_FILE=200 during dev runs. @@ -400,13 +517,18 @@ Memory MCP (8000 SSE, 8002 RMCP): Indexer/Search MCP (8001 SSE, 8003 RMCP): - repo_search β€” hybrid code search (dense + lexical + optional reranker) - context_search β€” search that can also blend memory results (include_memories) +- context_answer β€” natural-language Q&A with retrieval + local LLM (llama.cpp or GLM) - code_search β€” alias of repo_search - repo_search_compat β€” permissive wrapper that normalizes q/text/queries/top_k payloads +- context_answer_compat β€” permissive wrapper for context_answer with lenient argument handling +- expand_query(query, max_new?) β€” LLM-assisted query expansion (generates 1-2 alternates) - qdrant_index_root β€” index /work (mounted repo root) with safe defaults - qdrant_index(subdir?, recreate?, collection?) β€” index a subdir or recreate collection - qdrant_prune β€” remove points for missing files or file_hash mismatch - qdrant_list β€” list Qdrant collections - qdrant_status β€” collection counts and recent ingestion timestamps +- workspace_info(workspace_path?) β€” read .codebase/state.json and resolve default collection +- list_workspaces(search_root?) β€” scan for multiple workspaces in multi-repo environments - memory_store β€” convenience memory store from the indexer (uses default collection) - search_tests_for β€” intent wrapper for test files - search_config_for β€” intent wrapper for likely config files @@ -417,6 +539,7 @@ Indexer/Search MCP (8001 SSE, 8003 RMCP): Notes: - Most search tools accept filters like language, under, path_glob, kind, symbol, ext. - Reranker enabled by default; timeouts fall back to hybrid results. +- context_answer requires decoder enabled (REFRAG_DECODER=1) with llama.cpp or GLM backend. ### Qodo Integration (RMCP config) diff --git a/docker-compose.arm64.yml b/docker-compose.arm64.yml new file mode 100644 index 00000000..90a33235 --- /dev/null +++ b/docker-compose.arm64.yml @@ -0,0 +1,5 @@ +services: + qdrant: + platform: linux/arm64/v8 + llamacpp: + platform: linux/arm64/v8 diff --git a/docker-compose.yml b/docker-compose.yml index 11e8a63c..ccb39f5c 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -57,6 +57,7 @@ services: - DEBUG_CONTEXT_ANSWER=${DEBUG_CONTEXT_ANSWER:-1} - REFRAG_DECODER=${REFRAG_DECODER:-1} - LLAMACPP_URL=${LLAMACPP_URL:-http://llamacpp:8080} + - USE_GPU_DECODER=${USE_GPU_DECODER:-0} - LLAMACPP_TIMEOUT_SEC=${LLAMACPP_TIMEOUT_SEC:-180} - CTX_REQUIRE_IDENTIFIER=${CTX_REQUIRE_IDENTIFIER:-0} - COLLECTION_NAME=${COLLECTION_NAME:-my-collection} @@ -115,6 +116,7 @@ services: - DEBUG_CONTEXT_ANSWER=${DEBUG_CONTEXT_ANSWER:-1} - REFRAG_DECODER=${REFRAG_DECODER:-1} - LLAMACPP_URL=${LLAMACPP_URL:-http://llamacpp:8080} + - USE_GPU_DECODER=${USE_GPU_DECODER:-0} - LLAMACPP_TIMEOUT_SEC=${LLAMACPP_TIMEOUT_SEC:-180} - CTX_REQUIRE_IDENTIFIER=${CTX_REQUIRE_IDENTIFIER:-0} - COLLECTION_NAME=${COLLECTION_NAME:-my-collection} @@ -135,17 +137,44 @@ services: # Optional sidecar providing a text-generation API on :8080 # No behavior change unless REFRAG_DECODER=1 environment: - - LLAMA_ARG_MODEL=/models/model.gguf - - LLAMA_ARG_CTX_SIZE=8192 - - LLAMA_ARG_HOST=0.0.0.0 - - LLAMA_ARG_PORT=8080 + - LLAMACPP_CTX_SIZE=${LLAMACPP_CTX_SIZE:-8192} + - LLAMACPP_HOST=0.0.0.0 + - LLAMACPP_PORT=8080 + - LLAMACPP_USE_GPU=${LLAMACPP_USE_GPU:-0} + - LLAMACPP_GPU_LAYERS=${LLAMACPP_GPU_LAYERS:-0} + - LLAMACPP_GPU_SPLIT=${LLAMACPP_GPU_SPLIT:-} + - LLAMACPP_THREADS=${LLAMACPP_THREADS:-} + - LLAMACPP_EXTRA_ARGS=${LLAMACPP_EXTRA_ARGS:-} + - LLAMACPP_NO_WARMUP=${LLAMACPP_NO_WARMUP:-1} + - LLAMACPP_TEMPERATURE=${LLAMACPP_TEMPERATURE:-} ports: - "8080:8080" volumes: - ./models:/models:ro - # The server image's entrypoint is already the server binary; pass args only - command: ["--model", "/models/model.gguf", "--host", "0.0.0.0", "--port", "8080", "--no-warmup"] - + entrypoint: ["/bin/sh","-lc"] + command: + - | + set -e + ARGS="--model /models/model.gguf --host ${LLAMACPP_HOST:-0.0.0.0} --port ${LLAMACPP_PORT:-8080} --ctx-size ${LLAMACPP_CTX_SIZE:-8192}" + if [ "${LLAMACPP_USE_GPU:-0}" = "1" ]; then + LAYERS="${LLAMACPP_GPU_LAYERS:--1}" + else + LAYERS="${LLAMACPP_GPU_LAYERS:-0}" + fi + ARGS="$$ARGS --n-gpu-layers $${LAYERS}" + if [ -n "${LLAMACPP_GPU_SPLIT:-}" ]; then + ARGS="$$ARGS --tensor-split ${LLAMACPP_GPU_SPLIT}" + fi + if [ "${LLAMACPP_NO_WARMUP:-1}" != "0" ]; then + ARGS="$$ARGS --no-warmup" + fi + if [ -n "${LLAMACPP_THREADS:-}" ]; then + ARGS="$$ARGS --threads ${LLAMACPP_THREADS}" + fi + if [ -n "${LLAMACPP_EXTRA_ARGS:-}" ]; then + ARGS="$$ARGS ${LLAMACPP_EXTRA_ARGS}" + fi + exec /app/llama-server $$ARGS indexer: build: diff --git a/mcp-proxy/mcp-indexer-proxy.py b/mcp-proxy/mcp-indexer-proxy.py new file mode 100644 index 00000000..5a32910a --- /dev/null +++ b/mcp-proxy/mcp-indexer-proxy.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python3 +""" +MCP proxy script for connecting Zed to the Context-Engine indexer MCP server. +This script bridges stdio communication (expected by Zed) with HTTP communication. +""" + +import asyncio +import json +import sys +import aiohttp +import uuid + + +async def main(): + """Main proxy function that bridges stdio and HTTP MCP communication.""" + session_id = str(uuid.uuid4()) + base_url = "http://localhost:8003" + + async with aiohttp.ClientSession() as session: + # Start SSE connection + async with session.get(f"{base_url}/sse") as response: + if response.status != 200: + print( + f"Failed to connect to MCP server \ No newline at end of file diff --git a/scripts/gpu_toggle.sh b/scripts/gpu_toggle.sh new file mode 100755 index 00000000..f6bcc514 --- /dev/null +++ b/scripts/gpu_toggle.sh @@ -0,0 +1,232 @@ +#!/bin/bash +# GPU Decoder Toggle Script +# Manages switching between Docker CPU-only and native GPU-accelerated decoders + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(dirname "$SCRIPT_DIR")" +ENV_FILE="$PROJECT_ROOT/.env" + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +usage() { + echo "Usage: $0 [COMMAND]" + echo "" + echo "Commands:" + echo " status Show current decoder configuration" + echo " docker Switch to Docker CPU-only decoder (stable)" + echo " gpu Switch to native GPU-accelerated decoder (fast)" + echo " start Start native GPU decoder on port 8081" + echo " stop Stop native GPU decoder" + echo " test Test current decoder configuration" + echo " help Show this help message" + echo "" + echo "Examples:" + echo " $0 status # Check current setup" + echo " $0 gpu && $0 start # Switch to GPU and start native server" + echo " $0 docker # Switch back to Docker CPU-only" +} + +get_env_value() { + local key="$1" + local default="${2:-}" + + if [[ -f "$ENV_FILE" ]]; then + grep "^${key}=" "$ENV_FILE" 2>/dev/null | cut -d'=' -f2- | tr -d '"' || echo "$default" + else + echo "$default" + fi +} + +set_env_value() { + local key="$1" + local value="$2" + + if [[ -f "$ENV_FILE" ]]; then + # Update existing value or add new one + if grep -q "^${key}=" "$ENV_FILE"; then + sed -i.bak "s/^${key}=.*/${key}=${value}/" "$ENV_FILE" + else + echo "${key}=${value}" >> "$ENV_FILE" + fi + else + echo "${key}=${value}" > "$ENV_FILE" + fi + echo -e "${GREEN}Set ${key}=${value}${NC}" +} + +show_status() { + echo -e "${BLUE}Current Decoder Configuration${NC}" + echo "" + + local use_gpu=$(get_env_value "USE_GPU_DECODER" "0") + local llamacpp_url=$(get_env_value "LLAMACPP_URL" "http://llamacpp:8080") + + if [[ "$use_gpu" == "1" ]]; then + echo -e "Mode: ${GREEN}GPU-Accelerated${NC} (native llama.cpp with Metal)" + echo -e "URL: ${GREEN}http://host.docker.internal:8081${NC}" + echo "" + + # Check if native server is running + if curl -s -o /dev/null -w "%{http_code}" http://localhost:8081/health 2>/dev/null | grep -q "200"; then + echo -e "Status: ${GREEN}Native GPU server is running${NC}" + else + echo -e "Status: ${RED}Native GPU server is not running${NC}" + echo -e " ${YELLOW}Run: $0 start${NC}" + fi + else + echo -e "Mode: ${YELLOW}Docker CPU-Only${NC} (stable, containerized)" + echo -e "URL: ${YELLOW}${llamacpp_url}${NC}" + echo "" + + # Check if Docker container is running + if docker ps --format '{{.Names}}' | grep -q "llama-decoder"; then + echo -e "Status: ${GREEN}Docker container is running${NC}" + else + echo -e "Status: ${RED}Docker container is not running${NC}" + echo -e " ${YELLOW}Run: docker compose up llamacpp -d${NC}" + fi + fi +} + +switch_to_docker() { + echo -e "${BLUE}Switching to Docker CPU-only decoder${NC}" + set_env_value "USE_GPU_DECODER" "0" + echo "" + echo -e "${GREEN}Switched to Docker mode${NC}" + echo -e " - Stable and containerized" + echo -e " - CPU-only inference" + echo -e " - Uses Docker service: llamacpp:8080" + echo "" + echo -e "${YELLOW}Restart your indexer services to apply changes:${NC}" + echo -e " docker compose restart mcp_indexer mcp_indexer_http" +} + +switch_to_gpu() { + echo -e "${BLUE}Switching to native GPU-accelerated decoder${NC}" + set_env_value "USE_GPU_DECODER" "1" + echo "" + echo -e "${GREEN}Switched to GPU mode${NC}" + echo -e " - Metal GPU acceleration (Apple Silicon)" + echo -e " - Significantly faster inference" + echo -e " - Uses native server: localhost:8081" + echo "" + echo -e "${YELLOW}Next steps:${NC}" + echo -e " 1. Start native server: $0 start" + echo -e " 2. Restart indexer: docker compose restart mcp_indexer mcp_indexer_http" +} + +start_native_server() { + echo -e "${BLUE}Starting native GPU-accelerated decoder${NC}" + + # Check if already running + if curl -s -o /dev/null -w "%{http_code}" http://localhost:8081/health 2>/dev/null | grep -q "200"; then + echo -e "${YELLOW}Native server is already running on port 8081${NC}" + return 0 + fi + + # Check if llama-server is available + if ! command -v llama-server &> /dev/null; then + echo -e "${RED}llama-server not found${NC}" + echo -e " Install with: brew install llama.cpp" + return 1 + fi + + # Check if model exists + local model_path="$PROJECT_ROOT/models/model.gguf" + if [[ ! -f "$model_path" ]]; then + echo -e "${RED}Model not found: $model_path${NC}" + return 1 + fi + + echo -e "${GREEN}Starting native llama-server with GPU acceleration...${NC}" + echo -e " Model: $model_path" + echo -e " GPU Layers: 32" + echo -e " Port: 8081" + echo "" + + # Start in background + nohup llama-server \ + --model "$model_path" \ + --host 0.0.0.0 \ + --port 8081 \ + --n-gpu-layers 32 \ + --ctx-size 8192 \ + --no-warmup \ + > "$PROJECT_ROOT/llamacpp-gpu.log" 2>&1 & + + local pid=$! + echo -e "${GREEN}Started native server (PID: $pid)${NC}" + echo -e " Logs: $PROJECT_ROOT/llamacpp-gpu.log" + echo -e " Health: http://localhost:8081/health" + + # Wait a moment and check if it started successfully + sleep 3 + if curl -s -o /dev/null -w "%{http_code}" http://localhost:8081/health 2>/dev/null | grep -q "200"; then + echo -e "${GREEN}Server is healthy and ready${NC}" + else + echo -e "${YELLOW}Server may still be starting up. Check logs if issues persist.${NC}" + fi +} + +stop_native_server() { + echo -e "${BLUE}Stopping native GPU decoder${NC}" + + # Find and kill llama-server processes on port 8081 + local pids=$(lsof -ti:8081 2>/dev/null || true) + if [[ -n "$pids" ]]; then + echo -e "${GREEN}Stopping processes: $pids${NC}" + kill $pids + sleep 2 + + # Force kill if still running + local remaining=$(lsof -ti:8081 2>/dev/null || true) + if [[ -n "$remaining" ]]; then + echo -e "${YELLOW}Force killing remaining processes: $remaining${NC}" + kill -9 $remaining + fi + + echo -e "${GREEN}Native server stopped${NC}" + else + echo -e "${YELLOW}No native server found running on port 8081${NC}" + fi +} + +test_decoder() { + echo -e "${BLUE}Testing current decoder configuration${NC}" + echo "" + + cd "$PROJECT_ROOT" + python test_gpu_switch.py +} + +# Main command handling +case "${1:-help}" in + "status") + show_status + ;; + "docker") + switch_to_docker + ;; + "gpu") + switch_to_gpu + ;; + "start") + start_native_server + ;; + "stop") + stop_native_server + ;; + "test") + test_decoder + ;; + "help"|*) + usage + ;; +esac diff --git a/scripts/mcp_indexer_server.py b/scripts/mcp_indexer_server.py index 8695ebca..30739686 100644 --- a/scripts/mcp_indexer_server.py +++ b/scripts/mcp_indexer_server.py @@ -23,6 +23,7 @@ Note: We use the fastmcp library for quick SSE hosting. If you change to another MCP server framework, keep the tool names and args stable. """ + from __future__ import annotations import json import asyncio @@ -72,12 +73,15 @@ safe_float, safe_bool, ) + logger = get_logger(__name__) except ImportError: # Fallback if logger module not available import logging + logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) + # Define fallback safe conversion functions def safe_int(value, default=0, logger=None, context=""): try: @@ -86,6 +90,7 @@ def safe_int(value, default=0, logger=None, context=""): return int(value) except (ValueError, TypeError): return default + def safe_float(value, default=0.0, logger=None, context=""): try: if value is None or (isinstance(value, str) and value.strip() == ""): @@ -93,6 +98,7 @@ def safe_float(value, default=0.0, logger=None, context=""): return float(value) except (ValueError, TypeError): return default + def safe_bool(value, default=False, logger=None, context=""): try: if value is None or (isinstance(value, str) and value.strip() == ""): @@ -108,22 +114,32 @@ def safe_bool(value, default=False, logger=None, context=""): except (ValueError, TypeError): return default + +# Global lock to guard temporary env toggles used during ReFRAG retrieval/decoding +_ENV_LOCK = threading.Lock() + # Shared utilities (lex hashing, snippet highlighter) try: from scripts.utils import highlight_snippet as _do_highlight_snippet -except Exception: +except Exception as e: + logger.warning(f"Failed to import rich for syntax highlighting: {e}") _do_highlight_snippet = None # fallback guarded at call site # Back-compat shim for tests expecting _highlight_snippet in this module # Delegates to scripts.utils.highlight_snippet when available try: + def _highlight_snippet(snippet, tokens): # type: ignore - return _do_highlight_snippet(snippet, tokens) if _do_highlight_snippet else snippet + return ( + _do_highlight_snippet(snippet, tokens) if _do_highlight_snippet else snippet + ) except Exception: + def _highlight_snippet(snippet, tokens): # type: ignore return snippet + try: # Official MCP Python SDK (FastMCP convenience server) from mcp.server.fastmcp import FastMCP @@ -132,10 +148,12 @@ def _highlight_snippet(snippet, tokens): # type: ignore APP_NAME = os.environ.get("FASTMCP_SERVER_NAME", "qdrant-indexer-mcp") HOST = os.environ.get("FASTMCP_HOST", "0.0.0.0") -PORT = safe_int(os.environ.get("FASTMCP_INDEXER_PORT", "8001"), default=8001, logger=logger, context="FASTMCP_INDEXER_PORT") - -# Process-wide lock to guard environment mutations during retrieval (gate-first/budgeting) -_ENV_LOCK = threading.RLock() +PORT = safe_int( + os.environ.get("FASTMCP_INDEXER_PORT", "8001"), + default=8001, + logger=logger, + context="FASTMCP_INDEXER_PORT", +) # Context manager to temporarily override environment variables safely @@ -155,6 +173,7 @@ def _env_overrides(pairs: Dict[str, str]): else: os.environ[k] = old + # Set default environment variables for context_answer functionality # These are set in docker-compose.yml but provide fallbacks for local dev @@ -165,6 +184,7 @@ def _primary_identifier_from_queries(qs: list[str]) -> str: """ try: import re as _re + cand: list[str] = [] for q in qs: for t in _re.findall(r"[A-Za-z_][A-Za-z0-9_]*", q or ""): @@ -173,7 +193,9 @@ def _primary_identifier_from_queries(qs: list[str]) -> str: # Accept: ALL_CAPS, has_underscore, camelCase (mixed case), or longer lowercase is_all_caps = t.isupper() has_underscore = "_" in t - is_camel = any(c.isupper() for c in t[1:]) and any(c.islower() for c in t) + is_camel = any(c.isupper() for c in t[1:]) and any( + c.islower() for c in t + ) is_longer_lower = t.islower() and len(t) >= 3 if is_all_caps or has_underscore or is_camel or is_longer_lower: @@ -189,35 +211,57 @@ def _score(token: str) -> int: if token.isupper(): score += 100 # ALL_CAPS highest priority if "_" in token: - score += 50 # snake_case + score += 50 # snake_case if any(c.isupper() for c in token[1:]) and any(c.islower() for c in token): - score += 75 # camelCase + score += 75 # camelCase score += len(token) # Longer is slightly better return score cand.sort(key=_score, reverse=True) return cand[0] if cand else "" - except Exception: + except Exception as e: + logger.debug(f"Primary identifier extraction failed: {e}") return "" + QDRANT_URL = os.environ.get("QDRANT_URL", "http://qdrant:6333") DEFAULT_COLLECTION = os.environ.get("COLLECTION_NAME", "my-collection") -MAX_LOG_TAIL = safe_int(os.environ.get("MCP_MAX_LOG_TAIL", "4000"), default=4000, logger=logger, context="MCP_MAX_LOG_TAIL") -SNIPPET_MAX_BYTES = safe_int(os.environ.get("MCP_SNIPPET_MAX_BYTES", "8192"), default=8192, logger=logger, context="MCP_SNIPPET_MAX_BYTES") +MAX_LOG_TAIL = safe_int( + os.environ.get("MCP_MAX_LOG_TAIL", "4000"), + default=4000, + logger=logger, + context="MCP_MAX_LOG_TAIL", +) +SNIPPET_MAX_BYTES = safe_int( + os.environ.get("MCP_SNIPPET_MAX_BYTES", "8192"), + default=8192, + logger=logger, + context="MCP_SNIPPET_MAX_BYTES", +) -MCP_TOOL_TIMEOUT_SECS = safe_float(os.environ.get("MCP_TOOL_TIMEOUT_SECS", "3600"), default=3600.0, logger=logger, context="MCP_TOOL_TIMEOUT_SECS") +MCP_TOOL_TIMEOUT_SECS = safe_float( + os.environ.get("MCP_TOOL_TIMEOUT_SECS", "3600"), + default=3600.0, + logger=logger, + context="MCP_TOOL_TIMEOUT_SECS", +) # Set default environment variables for context_answer functionality os.environ.setdefault("DEBUG_CONTEXT_ANSWER", "1") os.environ.setdefault("REFRAG_DECODER", "1") os.environ.setdefault("LLAMACPP_URL", "http://localhost:8080") -os.environ.setdefault("CTX_REQUIRE_IDENTIFIER", "0") # Disable strict identifier requirement +os.environ.setdefault("USE_GPU_DECODER", "0") +os.environ.setdefault( + "CTX_REQUIRE_IDENTIFIER", "0" +) # Disable strict identifier requirement + # --- Workspace state integration helpers --- def _state_file_path(ws_path: str = "/work") -> str: try: return os.path.join(ws_path, ".codebase", "state.json") - except Exception: + except Exception as e: + logger.warning(f"State file path construction failed, using fallback: {e}") return "/work/.codebase/state.json" @@ -229,7 +273,8 @@ def _read_ws_state(ws_path: str = "/work") -> Optional[Dict[str, Any]]: with open(p, "r", encoding="utf-8") as f: obj = json.load(f) return obj if isinstance(obj, dict) else None - except Exception: + except Exception as e: + logger.debug(f"Failed to read workspace state: {e}") return None @@ -244,16 +289,16 @@ def _default_collection() -> str: return os.environ.get("COLLECTION_NAME", DEFAULT_COLLECTION) - def _work_script(name: str) -> str: """Return path to a script under /work if present, else local ./scripts. Keeps Docker/default behavior but works in local dev without /work mount. """ try: - w = os.path.join("/work", "scripts", name) - if os.path.exists(w): - return w - except Exception: + p = os.path.join("/work", "scripts", name) + if os.path.exists(p): + return p + except Exception as e: + logger.debug(f"Failed to locate script {name}: {e}") pass return os.path.join(os.getcwd(), "scripts", name) @@ -261,6 +306,7 @@ def _work_script(name: str) -> str: # Invalidate router scratchpad after reindex to avoid stale state reuse _def_ws = "/work" + def _invalidate_router_scratchpad(ws_path: str = _def_ws) -> bool: try: p = os.path.join(ws_path, ".codebase", "router_scratchpad.json") @@ -279,29 +325,83 @@ def _invalidate_router_scratchpad(ws_path: str = _def_ws) -> bool: _TOOLS_REGISTRY: list[dict] = [] try: _orig_tool = mcp.tool + def _tool_capture_wrapper(*dargs, **dkwargs): orig_deco = _orig_tool(*dargs, **dkwargs) + def _inner(fn): try: - _TOOLS_REGISTRY.append({ - "name": dkwargs.get("name") or getattr(fn, "__name__", ""), - "description": (getattr(fn, "__doc__", None) or "").strip(), - }) + _TOOLS_REGISTRY.append( + { + "name": dkwargs.get("name") or getattr(fn, "__name__", ""), + "description": (getattr(fn, "__doc__", None) or "").strip(), + } + ) except (AttributeError, TypeError) as e: logger.warning(f"Failed to capture tool metadata for {fn}", exc_info=e) return orig_deco(fn) + return _inner + mcp.tool = _tool_capture_wrapper # type: ignore except (AttributeError, TypeError) as e: logger.warning("Failed to wrap mcp.tool decorator", exc_info=e) + +def _relax_var_kwarg_defaults() -> None: + """Allow tools that rely on **kwargs compatibility shims to be invoked without + callers supplying an explicit 'kwargs' or 'arguments' field.""" + try: + from pydantic_core import PydanticUndefined as _PydanticUndefined # type: ignore + except Exception: # pragma: no cover - defensive + + class _Sentinel: # type: ignore + pass + + _PydanticUndefined = _Sentinel() # type: ignore + + try: + tool_manager = getattr(mcp, "_tool_manager", None) + tools = getattr(tool_manager, "_tools", {}) if tool_manager is not None else {} + except Exception: + tools = {} + + for tool in tools.values(): + try: + model = getattr(tool.fn_metadata, "arg_model", None) + if model is None: + continue + fields = getattr(model, "model_fields", {}) + changed = False + for key in ("kwargs", "arguments"): + fld = fields.get(key) + if fld is None: + continue + default = getattr(fld, "default", None) + default_factory = getattr(fld, "default_factory", None) + if default is _PydanticUndefined and default_factory is None: + try: + fld.default_factory = dict # type: ignore[attr-defined] + except Exception: + fld.default_factory = lambda: {} # type: ignore + fld.default = None + changed = True + if changed: + try: + model.model_rebuild(force=True) + except Exception: + pass + except Exception: + continue + + # Lightweight readiness endpoint on a separate health port (non-MCP), optional # Exposes GET /readyz returning {ok: true, app: } once process is up. HEALTH_PORT = safe_int( os.environ.get("FASTMCP_HEALTH_PORT", "18001"), default=18001, logger=logger, - context="FASTMCP_HEALTH_PORT" + context="FASTMCP_HEALTH_PORT", ) @@ -330,7 +430,11 @@ def do_GET(self): is_decoder_enabled = lambda: False # type: ignore try: if not is_decoder_enabled(): - tools = [t for t in tools if (t.get("name") or "") != "expand_query"] + tools = [ + t + for t in tools + if (t.get("name") or "") != "expand_query" + ] except Exception: pass payload = {"ok": True, "tools": tools} @@ -345,8 +449,6 @@ def do_GET(self): except Exception: pass - - def log_message(self, *args, **kwargs): # Quiet health server logs return @@ -366,7 +468,11 @@ def log_message(self, *args, **kwargs): # Fallback if subprocess_manager not available logger.warning("subprocess_manager not available, using fallback implementation") - async def run_subprocess_async(cmd: List[str], timeout: Optional[float] = None, env: Optional[Dict[str, str]] = None) -> Dict[str, Any]: + async def run_subprocess_async( + cmd: List[str], + timeout: Optional[float] = None, + env: Optional[Dict[str, str]] = None, + ) -> Dict[str, Any]: """Fallback subprocess runner if subprocess_manager is not available.""" proc: Optional[asyncio.subprocess.Process] = None try: @@ -429,9 +535,12 @@ def _cap_tail(s: str) -> str: except Exception: pass + # Async subprocess runner to avoid blocking event loop async def _run_async( - cmd: list[str], env: Optional[Dict[str, str]] = None, timeout: Optional[float] = None + cmd: list[str], + env: Optional[Dict[str, str]] = None, + timeout: Optional[float] = None, ) -> Dict[str, Any]: """Run subprocess with proper resource management using SubprocessManager.""" # Default timeout from env if not provided by caller @@ -518,7 +627,8 @@ def _parse_kv_string(s: str) -> _Dict[str, _Any]: k, v = part.split("=", 1) out[k.strip()] = _coerce_value_string(v.strip()) return out - except Exception: + except Exception as e: + logger.debug(f"Failed to parse KV string '{input_str}': {e}") return {} return out @@ -570,6 +680,14 @@ def _to_str_list_relaxed(x: _Any) -> list[str]: def _extract_kwargs_payload(kwargs: _Any) -> _Dict[str, _Any]: try: + # Handle kwargs being passed as a string "{}" by some MCP clients + if isinstance(kwargs, str): + parsed = _maybe_parse_jsonish(kwargs) + if isinstance(parsed, dict): + kwargs = parsed + else: + return {} + if isinstance(kwargs, dict) and "kwargs" in kwargs: inner = kwargs.get("kwargs") if isinstance(inner, dict): @@ -719,7 +837,10 @@ async def qdrant_index_root( coll = _c else: try: - from scripts.workspace_state import get_collection_name as _ws_get_collection_name # type: ignore + from scripts.workspace_state import ( + get_collection_name as _ws_get_collection_name, + ) # type: ignore + coll = _ws_get_collection_name("/work") except Exception: coll = _default_collection() @@ -744,7 +865,7 @@ async def qdrant_index_root( @mcp.tool() -async def qdrant_list(**kwargs) -> Dict[str, Any]: +async def qdrant_list(kwargs: Any = None) -> Dict[str, Any]: """List available Qdrant collections. When to use: @@ -773,9 +894,10 @@ async def qdrant_list(**kwargs) -> Dict[str, Any]: return {"error": str(e)} - @mcp.tool() -async def workspace_info(workspace_path: Optional[str] = None, **kwargs) -> Dict[str, Any]: +async def workspace_info( + workspace_path: Optional[str] = None, kwargs: Any = None +) -> Dict[str, Any]: """Read .codebase/state.json for the current workspace and resolve defaults. When to use: @@ -790,7 +912,11 @@ async def workspace_info(workspace_path: Optional[str] = None, **kwargs) -> Dict """ ws_path = (workspace_path or "/work").strip() or "/work" st = _read_ws_state(ws_path) or {} - coll = (st.get("qdrant_collection") if isinstance(st, dict) else None) or os.environ.get("COLLECTION_NAME") or DEFAULT_COLLECTION + coll = ( + (st.get("qdrant_collection") if isinstance(st, dict) else None) + or os.environ.get("COLLECTION_NAME") + or DEFAULT_COLLECTION + ) return { "workspace_path": ws_path, "default_collection": coll, @@ -798,6 +924,7 @@ async def workspace_info(workspace_path: Optional[str] = None, **kwargs) -> Dict "state": st or {}, } + @mcp.tool() async def list_workspaces(search_root: Optional[str] = None) -> Dict[str, Any]: """Scan search_root recursively for .codebase/state.json and summarize workspaces. @@ -813,6 +940,7 @@ async def list_workspaces(search_root: Optional[str] = None) -> Dict[str, Any]: """ try: from scripts.workspace_state import list_workspaces as _lw # type: ignore + items = await asyncio.to_thread(lambda: _lw(search_root)) return {"workspaces": items} except Exception as e: @@ -951,7 +1079,7 @@ async def qdrant_status( collection: Optional[str] = None, max_points: Optional[int] = None, batch: Optional[int] = None, - **kwargs, + kwargs: Any = None, ) -> Dict[str, Any]: """Summarize collection size and recent index timestamps. @@ -975,8 +1103,6 @@ async def qdrant_status( if _extra and not collection: collection = _extra.get("collection", collection) if _extra and max_points in (None, "") and _extra.get("max_points") is not None: - - max_points = _coerce_int(_extra.get("max_points"), None) if _extra and batch in (None, "") and _extra.get("batch") is not None: batch = _coerce_int(_extra.get("batch"), None) @@ -1092,8 +1218,6 @@ async def qdrant_index( if _looks_jsonish_string(collection): _parsed = _maybe_parse_jsonish(collection) if isinstance(_parsed, dict): - - subdir = _parsed.get("subdir", subdir) collection = _parsed.get("collection", collection) if recreate is None and "recreate" in _parsed: @@ -1127,7 +1251,10 @@ async def qdrant_index( coll = _c2 else: try: - from scripts.workspace_state import get_collection_name as _ws_get_collection_name # type: ignore + from scripts.workspace_state import ( + get_collection_name as _ws_get_collection_name, + ) # type: ignore + coll = _ws_get_collection_name("/work") except Exception: coll = _default_collection() @@ -1157,7 +1284,7 @@ async def qdrant_index( @mcp.tool() -async def qdrant_prune(**kwargs) -> Dict[str, Any]: +async def qdrant_prune(kwargs: Any = None) -> Dict[str, Any]: """Remove stale points for /work (files deleted/moved but still in the index). When to use: @@ -1180,9 +1307,9 @@ async def qdrant_prune(**kwargs) -> Dict[str, Any]: @mcp.tool() async def repo_search( query: Any = None, + queries: Any = None, # Alias for query (many clients use this) limit: Any = None, per_path: Any = None, - include_snippet: Any = None, context_lines: Any = None, rerank_enabled: Any = None, @@ -1206,7 +1333,7 @@ async def repo_search( case: Any = None, # Response shaping compact: Any = None, - **kwargs, + kwargs: Any = None, ) -> Dict[str, Any]: """Zero-config code search over repositories (hybrid: vector + lexical RRF, optional rerank). @@ -1233,13 +1360,17 @@ async def repo_search( - path_glob=["scripts/**","**/*.py"], language="python" - symbol="context_answer", under="scripts" """ + # Handle queries alias (explicit parameter) + if queries is not None and (query is None or (isinstance(query, str) and str(query).strip() == "")): + query = queries + # Accept common alias keys from clients (top-level) try: - if ( + if kwargs and ( limit is None or (isinstance(limit, str) and str(limit).strip() == "") ) and ("top_k" in kwargs): limit = kwargs.get("top_k") - if query is None or (isinstance(query, str) and str(query).strip() == ""): + if kwargs and (query is None or (isinstance(query, str) and str(query).strip() == "")): q_alt = kwargs.get("q") or kwargs.get("text") if q_alt is not None: query = q_alt @@ -1292,7 +1423,11 @@ async def repo_search( collection = _extra.get("collection") # Optional workspace_path routing if ( - (workspace_path is None) or (isinstance(workspace_path, str) and str(workspace_path).strip() == "") + (workspace_path is None) + or ( + isinstance(workspace_path, str) + and str(workspace_path).strip() == "" + ) ) and _extra.get("workspace_path") is not None: workspace_path = _extra.get("workspace_path") if ( @@ -1433,7 +1568,7 @@ def _to_str_list(x): # Accept top-level alias `queries` as a drop-in for `query` # Many clients send queries=[...] instead of query=[...] - if "queries" in kwargs and kwargs.get("queries") is not None: + if kwargs and "queries" in kwargs and kwargs.get("queries") is not None: query = kwargs.get("queries") # Normalize queries to a list[str] (robust for JSON strings and arrays) @@ -1479,7 +1614,11 @@ def _to_str_list(x): items = run_hybrid_search( queries=queries, limit=int(limit), - per_path=(int(per_path) if (per_path is not None and str(per_path).strip() != "") else 1), + per_path=( + int(per_path) + if (per_path is not None and str(per_path).strip() != "") + else 1 + ), language=language or None, under=under or None, kind=kind or None, @@ -1517,7 +1656,7 @@ def _to_str_list(x): str(int(limit)), "--json", ] - if (per_path is not None and str(per_path).strip() != ""): + if per_path is not None and str(per_path).strip() != "": cmd += ["--per-path", str(int(per_path))] if language: cmd += ["--language", language] @@ -1553,7 +1692,7 @@ def _to_str_list(x): except json.JSONDecodeError: continue # Fallback: if subprocess yielded nothing (e.g., local dev without /work), try in-process once - if (not json_lines): + if not json_lines: try: from scripts.hybrid_search import run_hybrid_search # type: ignore @@ -1562,7 +1701,11 @@ def _to_str_list(x): items = run_hybrid_search( queries=queries, limit=int(limit), - per_path=(int(per_path) if (per_path is not None and str(per_path).strip() != "") else 1), + per_path=( + int(per_path) + if (per_path is not None and str(per_path).strip() != "") + else 1 + ), language=language or None, under=under or None, kind=kind or None, @@ -1573,7 +1716,8 @@ def _to_str_list(x): path_regex=path_regex or None, path_glob=(path_globs or None), not_glob=(not_globs or None), - expand=str(os.environ.get("HYBRID_EXPAND", "0")).strip().lower() in {"1", "true", "yes", "on"}, + expand=str(os.environ.get("HYBRID_EXPAND", "0")).strip().lower() + in {"1", "true", "yes", "on"}, model=model, ) json_lines = items @@ -1722,8 +1866,18 @@ def _doc_for(obj: dict) -> str: _t_sec = max(0.1, _eff_ms / 1000.0) rres = await _run_async(rcmd, env=env, timeout=_t_sec) if os.environ.get("MCP_DEBUG_RERANK", "").strip(): - - logger.debug("RERANK_RET", extra={"code": rres.get("code"), "out_len": len((rres.get("stdout") or "").strip()), "err_tail": (rres.get("stderr") or "")[-200:]}) + logger.debug( + "RERANK_RET", + extra={ + "code": rres.get("code"), + "out_len": len((rres.get("stdout") or "").strip()), + "err_tail": (rres.get("stderr") or "")[-200:], + }, + ) + if not rres.get("ok"): + _stderr = (rres.get("stderr") or "").lower() + if rres.get("code") == -1 or "timed out" in _stderr: + rerank_counters["timeout"] += 1 if rres.get("ok") and (rres.get("stdout") or "").strip(): rerank_counters["subprocess"] += 1 tmp = [] @@ -1845,7 +1999,20 @@ def _read_snip(args): # Compact mode: return only path and line range if os.environ.get("DEBUG_REPO_SEARCH"): - logger.debug("DEBUG_REPO_SEARCH", extra={"count": len(results), "sample": [{"path": r.get("path"), "symbol": r.get("symbol"), "range": f"{r.get('start_line')}-{r.get('end_line')}"} for r in results[:5]]}) + logger.debug( + "DEBUG_REPO_SEARCH", + extra={ + "count": len(results), + "sample": [ + { + "path": r.get("path"), + "symbol": r.get("symbol"), + "range": f"{r.get('start_line')}-{r.get('end_line')}", + } + for r in results[:5] + ], + }, + ) if compact: results = [ @@ -1908,7 +2075,9 @@ async def repo_search_compat(**arguments) -> Dict[str, Any]: queries = args.get("queries") # top_k alias for limit limit = args.get("limit") - if (limit is None or (isinstance(limit, str) and str(limit).strip() == "")) and ("top_k" in args): + if ( + limit is None or (isinstance(limit, str) and str(limit).strip() == "") + ) and ("top_k" in args): limit = args.get("top_k") # not/ not_ normalization not_value = args.get("not_") if ("not_" in args) else args.get("not") @@ -1950,7 +2119,6 @@ async def repo_search_compat(**arguments) -> Dict[str, Any]: return {"error": f"repo_search_compat failed: {e}"} - @mcp.tool() async def context_answer_compat(arguments: Any = None) -> Dict[str, Any]: """Compatibility wrapper for context_answer (lenient argument handling). @@ -1995,7 +2163,6 @@ async def context_answer_compat(arguments: Any = None) -> Dict[str, Any]: return {"error": f"context_answer_compat failed: {e}"} - @mcp.tool() async def search_tests_for( query: Any = None, @@ -2005,7 +2172,7 @@ async def search_tests_for( under: Any = None, language: Any = None, compact: Any = None, - **kwargs, + kwargs: Any = None, ) -> Dict[str, Any]: """Find test files related to a query. @@ -2026,7 +2193,9 @@ async def search_tests_for( "**/Test*/**", ] # Allow caller to add more with path_glob kwarg - extra_glob = kwargs.get("path_glob") + # Handle kwargs being passed as a string by some MCP clients + _kwargs = _extract_kwargs_payload(kwargs) if kwargs else {} + extra_glob = _kwargs.get("path_glob") if extra_glob: if isinstance(extra_glob, (list, tuple)): globs.extend([str(x) for x in extra_glob]) @@ -2041,7 +2210,7 @@ async def search_tests_for( language=language, path_glob=globs, compact=compact, - **{k: v for k, v in kwargs.items() if k not in {"path_glob"}} + kwargs={k: v for k, v in _kwargs.items() if k not in {"path_glob"}}, ) @@ -2053,7 +2222,7 @@ async def search_config_for( context_lines: Any = None, under: Any = None, compact: Any = None, - **kwargs, + kwargs: Any = None, ) -> Dict[str, Any]: """Find likely configuration files for a service/query. @@ -2079,7 +2248,9 @@ async def search_config_for( "**/*.xml", "**/appsettings*.json", ] - extra_glob = kwargs.get("path_glob") + # Handle kwargs being passed as a string by some MCP clients + _kwargs = _extract_kwargs_payload(kwargs) if kwargs else {} + extra_glob = _kwargs.get("path_glob") if extra_glob: if isinstance(extra_glob, (list, tuple)): globs.extend([str(x) for x in extra_glob]) @@ -2093,7 +2264,7 @@ async def search_config_for( under=under, path_glob=globs, compact=compact, - **{k: v for k, v in kwargs.items() if k not in {"path_glob"}} + kwargs={k: v for k, v in _kwargs.items() if k not in {"path_glob"}}, ) @@ -2102,7 +2273,7 @@ async def search_callers_for( query: Any = None, limit: Any = None, language: Any = None, - **kwargs, + kwargs: Any = None, ) -> Dict[str, Any]: """Heuristic search for callers/usages of a symbol. @@ -2117,7 +2288,7 @@ async def search_callers_for( query=query, limit=limit, language=language, - **kwargs, + kwargs=kwargs, ) @@ -2126,7 +2297,7 @@ async def search_importers_for( query: Any = None, limit: Any = None, language: Any = None, - **kwargs, + kwargs: Any = None, ) -> Dict[str, Any]: """Find files likely importing or referencing a module/symbol. @@ -2137,11 +2308,27 @@ async def search_importers_for( Returns: repo_search result shape. """ globs = [ - "**/*.py", "**/*.js", "**/*.ts", "**/*.tsx", "**/*.jsx", "**/*.mjs", "**/*.cjs", - "**/*.go", "**/*.java", "**/*.cs", "**/*.rb", "**/*.php", "**/*.rs", - "**/*.c", "**/*.h", "**/*.cpp", "**/*.hpp", + "**/*.py", + "**/*.js", + "**/*.ts", + "**/*.tsx", + "**/*.jsx", + "**/*.mjs", + "**/*.cjs", + "**/*.go", + "**/*.java", + "**/*.cs", + "**/*.rb", + "**/*.php", + "**/*.rs", + "**/*.c", + "**/*.h", + "**/*.cpp", + "**/*.hpp", ] - extra_glob = kwargs.get("path_glob") + # Handle kwargs being passed as a string by some MCP clients + _kwargs = _extract_kwargs_payload(kwargs) if kwargs else {} + extra_glob = _kwargs.get("path_glob") if extra_glob: if isinstance(extra_glob, (list, tuple)): globs.extend([str(x) for x in extra_glob]) @@ -2153,11 +2340,10 @@ async def search_importers_for( limit=limit, language=language, path_glob=globs, - **{k: v for k, v in kwargs.items() if k not in {"path_glob"}} + kwargs={k: v for k, v in _kwargs.items() if k not in {"path_glob"}}, ) - @mcp.tool() async def change_history_for_path( path: Any, @@ -2185,13 +2371,19 @@ async def change_history_for_path( try: from qdrant_client import QdrantClient # type: ignore from qdrant_client import models as qmodels # type: ignore + client = QdrantClient( url=QDRANT_URL, api_key=os.environ.get("QDRANT_API_KEY"), timeout=float(os.environ.get("QDRANT_TIMEOUT", "20") or 20), ) + # Strict exact match on metadata.path (Compose maps to /work) filt = qmodels.Filter( - must=[qmodels.FieldCondition(key="metadata.path", match=qmodels.MatchValue(value=p))] + must=[ + qmodels.FieldCondition( + key="metadata.path", match=qmodels.MatchValue(value=p) + ) + ] ) page = None total = 0 @@ -2201,7 +2393,14 @@ async def change_history_for_path( churns = [] while total < mcap: sc, page = await asyncio.to_thread( - lambda: client.scroll(collection_name=coll, with_payload=True, with_vectors=False, limit=200, offset=page, scroll_filter=filt) + lambda: client.scroll( + collection_name=coll, + with_payload=True, + with_vectors=False, + limit=200, + offset=page, + scroll_filter=filt, + ) ) if not sc: break @@ -2267,7 +2466,7 @@ async def context_search( not_: Any = None, case: Any = None, compact: Any = None, - **kwargs, + kwargs: Any = None, ) -> Dict[str, Any]: """Blend code search results with memory-store entries (notes, docs) for richer context. @@ -2339,13 +2538,22 @@ def _maybe_dict(val: Any) -> Dict[str, Any]: for payload in payloads: if not isinstance(payload, dict): continue - if (query is None or (isinstance(query, str) and query.strip() == "")) and payload.get("query") is not None: + if ( + query is None or (isinstance(query, str) and query.strip() == "") + ) and payload.get("query") is not None: query = payload.get("query") - if (query is None or (isinstance(query, str) and query.strip() == "")) and payload.get("queries") is not None: + if ( + query is None or (isinstance(query, str) and query.strip() == "") + ) and payload.get("queries") is not None: query = payload.get("queries") - if (limit is None or (isinstance(limit, str) and limit.strip() == "")) and payload.get("limit") is not None: + if ( + limit is None or (isinstance(limit, str) and limit.strip() == "") + ) and payload.get("limit") is not None: limit = payload.get("limit") - if (per_path is None or (isinstance(per_path, str) and str(per_path).strip() == "")) and payload.get("per_path") is not None: + if ( + per_path is None + or (isinstance(per_path, str) and str(per_path).strip() == "") + ) and payload.get("per_path") is not None: per_path = payload.get("per_path") if include_memories is None and payload.get("include_memories") is not None: include_memories = payload.get("include_memories") @@ -2359,68 +2567,142 @@ def _maybe_dict(val: Any) -> Dict[str, Any]: per_source_limits = payload.get("per_source_limits") if per_source_limits is None and payload.get("perSourceLimits") is not None: per_source_limits = payload.get("perSourceLimits") - if (include_snippet is None or include_snippet == "") and payload.get("include_snippet") is not None: + if (include_snippet is None or include_snippet == "") and payload.get( + "include_snippet" + ) is not None: include_snippet = payload.get("include_snippet") - if (include_snippet is None or include_snippet == "") and payload.get("includeSnippet") is not None: + if (include_snippet is None or include_snippet == "") and payload.get( + "includeSnippet" + ) is not None: include_snippet = payload.get("includeSnippet") - if (context_lines is None or (isinstance(context_lines, str) and context_lines.strip() == "")) and payload.get("context_lines") is not None: + if ( + context_lines is None + or (isinstance(context_lines, str) and context_lines.strip() == "") + ) and payload.get("context_lines") is not None: context_lines = payload.get("context_lines") - if (context_lines is None or (isinstance(context_lines, str) and context_lines.strip() == "")) and payload.get("contextLines") is not None: + if ( + context_lines is None + or (isinstance(context_lines, str) and context_lines.strip() == "") + ) and payload.get("contextLines") is not None: context_lines = payload.get("contextLines") - if (rerank_enabled is None or rerank_enabled == "") and payload.get("rerank_enabled") is not None: + if (rerank_enabled is None or rerank_enabled == "") and payload.get( + "rerank_enabled" + ) is not None: rerank_enabled = payload.get("rerank_enabled") - if (rerank_enabled is None or rerank_enabled == "") and payload.get("rerankEnabled") is not None: + if (rerank_enabled is None or rerank_enabled == "") and payload.get( + "rerankEnabled" + ) is not None: rerank_enabled = payload.get("rerankEnabled") - if (rerank_top_n is None or (isinstance(rerank_top_n, str) and rerank_top_n.strip() == "")) and payload.get("rerank_top_n") is not None: + if ( + rerank_top_n is None + or (isinstance(rerank_top_n, str) and rerank_top_n.strip() == "") + ) and payload.get("rerank_top_n") is not None: rerank_top_n = payload.get("rerank_top_n") - if (rerank_top_n is None or (isinstance(rerank_top_n, str) and rerank_top_n.strip() == "")) and payload.get("rerankTopN") is not None: + if ( + rerank_top_n is None + or (isinstance(rerank_top_n, str) and rerank_top_n.strip() == "") + ) and payload.get("rerankTopN") is not None: rerank_top_n = payload.get("rerankTopN") - if (rerank_return_m is None or (isinstance(rerank_return_m, str) and rerank_return_m.strip() == "")) and payload.get("rerank_return_m") is not None: + if ( + rerank_return_m is None + or (isinstance(rerank_return_m, str) and rerank_return_m.strip() == "") + ) and payload.get("rerank_return_m") is not None: rerank_return_m = payload.get("rerank_return_m") - if (rerank_return_m is None or (isinstance(rerank_return_m, str) and rerank_return_m.strip() == "")) and payload.get("rerankReturnM") is not None: + if ( + rerank_return_m is None + or (isinstance(rerank_return_m, str) and rerank_return_m.strip() == "") + ) and payload.get("rerankReturnM") is not None: rerank_return_m = payload.get("rerankReturnM") - if (rerank_timeout_ms is None or (isinstance(rerank_timeout_ms, str) and rerank_timeout_ms.strip() == "")) and payload.get("rerank_timeout_ms") is not None: + if ( + rerank_timeout_ms is None + or (isinstance(rerank_timeout_ms, str) and rerank_timeout_ms.strip() == "") + ) and payload.get("rerank_timeout_ms") is not None: rerank_timeout_ms = payload.get("rerank_timeout_ms") - if (rerank_timeout_ms is None or (isinstance(rerank_timeout_ms, str) and rerank_timeout_ms.strip() == "")) and payload.get("rerankTimeoutMs") is not None: + if ( + rerank_timeout_ms is None + or (isinstance(rerank_timeout_ms, str) and rerank_timeout_ms.strip() == "") + ) and payload.get("rerankTimeoutMs") is not None: rerank_timeout_ms = payload.get("rerankTimeoutMs") - if (highlight_snippet is None or highlight_snippet == "") and payload.get("highlight_snippet") is not None: + if (highlight_snippet is None or highlight_snippet == "") and payload.get( + "highlight_snippet" + ) is not None: highlight_snippet = payload.get("highlight_snippet") - if (highlight_snippet is None or highlight_snippet == "") and payload.get("highlightSnippet") is not None: + if (highlight_snippet is None or highlight_snippet == "") and payload.get( + "highlightSnippet" + ) is not None: highlight_snippet = payload.get("highlightSnippet") - if (collection is None or (isinstance(collection, str) and collection.strip() == "")) and payload.get("collection") is not None: + if ( + collection is None + or (isinstance(collection, str) and collection.strip() == "") + ) and payload.get("collection") is not None: collection = payload.get("collection") - if (language is None or (isinstance(language, str) and language.strip() == "")) and payload.get("language") is not None: + if ( + language is None or (isinstance(language, str) and language.strip() == "") + ) and payload.get("language") is not None: language = payload.get("language") - if (under is None or (isinstance(under, str) and under.strip() == "")) and payload.get("under") is not None: + if ( + under is None or (isinstance(under, str) and under.strip() == "") + ) and payload.get("under") is not None: under = payload.get("under") - if (kind is None or (isinstance(kind, str) and kind.strip() == "")) and payload.get("kind") is not None: + if ( + kind is None or (isinstance(kind, str) and kind.strip() == "") + ) and payload.get("kind") is not None: kind = payload.get("kind") - if (symbol is None or (isinstance(symbol, str) and symbol.strip() == "")) and payload.get("symbol") is not None: + if ( + symbol is None or (isinstance(symbol, str) and symbol.strip() == "") + ) and payload.get("symbol") is not None: symbol = payload.get("symbol") - if (path_regex is None or (isinstance(path_regex, str) and path_regex.strip() == "")) and payload.get("path_regex") is not None: + if ( + path_regex is None + or (isinstance(path_regex, str) and path_regex.strip() == "") + ) and payload.get("path_regex") is not None: path_regex = payload.get("path_regex") - if (path_regex is None or (isinstance(path_regex, str) and path_regex.strip() == "")) and payload.get("pathRegex") is not None: + if ( + path_regex is None + or (isinstance(path_regex, str) and path_regex.strip() == "") + ) and payload.get("pathRegex") is not None: path_regex = payload.get("pathRegex") - if (path_glob is None or (isinstance(path_glob, str) and str(path_glob).strip() == "")) and payload.get("path_glob") is not None: + if ( + path_glob is None + or (isinstance(path_glob, str) and str(path_glob).strip() == "") + ) and payload.get("path_glob") is not None: path_glob = payload.get("path_glob") - if (path_glob is None or (isinstance(path_glob, str) and str(path_glob).strip() == "")) and payload.get("pathGlob") is not None: + if ( + path_glob is None + or (isinstance(path_glob, str) and str(path_glob).strip() == "") + ) and payload.get("pathGlob") is not None: path_glob = payload.get("pathGlob") - if (not_glob is None or (isinstance(not_glob, str) and str(not_glob).strip() == "")) and payload.get("not_glob") is not None: + if ( + not_glob is None + or (isinstance(not_glob, str) and str(not_glob).strip() == "") + ) and payload.get("not_glob") is not None: not_glob = payload.get("not_glob") - if (not_glob is None or (isinstance(not_glob, str) and str(not_glob).strip() == "")) and payload.get("notGlob") is not None: + if ( + not_glob is None + or (isinstance(not_glob, str) and str(not_glob).strip() == "") + ) and payload.get("notGlob") is not None: not_glob = payload.get("notGlob") - if (ext is None or (isinstance(ext, str) and ext.strip() == "")) and payload.get("ext") is not None: + if ( + ext is None or (isinstance(ext, str) and ext.strip() == "") + ) and payload.get("ext") is not None: ext = payload.get("ext") - if (not_ is None or (isinstance(not_, str) and not_.strip() == "")) and payload.get("not") is not None: + if ( + not_ is None or (isinstance(not_, str) and not_.strip() == "") + ) and payload.get("not") is not None: not_ = payload.get("not") - if (not_ is None or (isinstance(not_, str) and not_.strip() == "")) and payload.get("not_") is not None: + if ( + not_ is None or (isinstance(not_, str) and not_.strip() == "") + ) and payload.get("not_") is not None: not_ = payload.get("not_") - if (case is None or (isinstance(case, str) and case.strip() == "")) and payload.get("case") is not None: + if ( + case is None or (isinstance(case, str) and case.strip() == "") + ) and payload.get("case") is not None: case = payload.get("case") - if (compact is None or (isinstance(compact, str) and compact.strip() == "")) and payload.get("compact") is not None: + if ( + compact is None or (isinstance(compact, str) and compact.strip() == "") + ) and payload.get("compact") is not None: compact = payload.get("compact") - # Leniency: absorb nested 'kwargs' JSON payload some clients send (string or dict) try: _extra = _extract_kwargs_payload(kwargs) @@ -2433,55 +2715,95 @@ def _maybe_dict(val: Any) -> Dict[str, Any]: per_path = _extra.get("per_path") # Memory blending controls if include_memories is None and ( - (_extra.get("include_memories") is not None) or (_extra.get("includeMemories") is not None) + (_extra.get("include_memories") is not None) + or (_extra.get("includeMemories") is not None) ): - include_memories = _extra.get("include_memories", _extra.get("includeMemories")) + include_memories = _extra.get( + "include_memories", _extra.get("includeMemories") + ) if memory_weight is None and ( - (_extra.get("memory_weight") is not None) or (_extra.get("memoryWeight") is not None) + (_extra.get("memory_weight") is not None) + or (_extra.get("memoryWeight") is not None) ): memory_weight = _extra.get("memory_weight", _extra.get("memoryWeight")) if per_source_limits is None and ( - (_extra.get("per_source_limits") is not None) or (_extra.get("perSourceLimits") is not None) + (_extra.get("per_source_limits") is not None) + or (_extra.get("perSourceLimits") is not None) ): - per_source_limits = _extra.get("per_source_limits", _extra.get("perSourceLimits")) + per_source_limits = _extra.get( + "per_source_limits", _extra.get("perSourceLimits") + ) # Passthrough search filters - if (include_snippet in (None, "")) and (_extra.get("include_snippet") is not None): + if (include_snippet in (None, "")) and ( + _extra.get("include_snippet") is not None + ): include_snippet = _extra.get("include_snippet") - if (context_lines in (None, "")) and (_extra.get("context_lines") is not None): + if (context_lines in (None, "")) and ( + _extra.get("context_lines") is not None + ): context_lines = _extra.get("context_lines") - if (rerank_enabled in (None, "")) and (_extra.get("rerank_enabled") is not None): + if (rerank_enabled in (None, "")) and ( + _extra.get("rerank_enabled") is not None + ): rerank_enabled = _extra.get("rerank_enabled") - if (rerank_top_n in (None, "")) and (_extra.get("rerank_top_n") is not None): + if (rerank_top_n in (None, "")) and ( + _extra.get("rerank_top_n") is not None + ): rerank_top_n = _extra.get("rerank_top_n") - if (rerank_return_m in (None, "")) and (_extra.get("rerank_return_m") is not None): + if (rerank_return_m in (None, "")) and ( + _extra.get("rerank_return_m") is not None + ): rerank_return_m = _extra.get("rerank_return_m") - if (rerank_timeout_ms in (None, "")) and (_extra.get("rerank_timeout_ms") is not None): + if (rerank_timeout_ms in (None, "")) and ( + _extra.get("rerank_timeout_ms") is not None + ): rerank_timeout_ms = _extra.get("rerank_timeout_ms") - if (highlight_snippet in (None, "")) and (_extra.get("highlight_snippet") is not None): + if (highlight_snippet in (None, "")) and ( + _extra.get("highlight_snippet") is not None + ): highlight_snippet = _extra.get("highlight_snippet") - if (collection is None or (isinstance(collection, str) and collection.strip() == "")) and _extra.get("collection"): + if ( + collection is None + or (isinstance(collection, str) and collection.strip() == "") + ) and _extra.get("collection"): collection = _extra.get("collection") - if (language is None or (isinstance(language, str) and language.strip() == "")) and _extra.get("language"): + if ( + language is None + or (isinstance(language, str) and language.strip() == "") + ) and _extra.get("language"): language = _extra.get("language") - if (under is None or (isinstance(under, str) and under.strip() == "")) and _extra.get("under"): + if ( + under is None or (isinstance(under, str) and under.strip() == "") + ) and _extra.get("under"): under = _extra.get("under") - if (kind is None or (isinstance(kind, str) and kind.strip() == "")) and _extra.get("kind"): + if ( + kind is None or (isinstance(kind, str) and kind.strip() == "") + ) and _extra.get("kind"): kind = _extra.get("kind") - if (symbol is None or (isinstance(symbol, str) and symbol.strip() == "")) and _extra.get("symbol"): + if ( + symbol is None or (isinstance(symbol, str) and symbol.strip() == "") + ) and _extra.get("symbol"): symbol = _extra.get("symbol") - if (path_regex is None or (isinstance(path_regex, str) and path_regex.strip() == "")) and _extra.get("path_regex"): + if ( + path_regex is None + or (isinstance(path_regex, str) and path_regex.strip() == "") + ) and _extra.get("path_regex"): path_regex = _extra.get("path_regex") if (path_glob in (None, "")) and (_extra.get("path_glob") is not None): path_glob = _extra.get("path_glob") if (not_glob in (None, "")) and (_extra.get("not_glob") is not None): not_glob = _extra.get("not_glob") - if (ext is None or (isinstance(ext, str) and ext.strip() == "")) and _extra.get("ext"): + if ( + ext is None or (isinstance(ext, str) and ext.strip() == "") + ) and _extra.get("ext"): ext = _extra.get("ext") if (not_ is None or (isinstance(not_, str) and not_.strip() == "")) and ( _extra.get("not") or _extra.get("not_") ): not_ = _extra.get("not") or _extra.get("not_") - if (case is None or (isinstance(case, str) and case.strip() == "")) and _extra.get("case"): + if ( + case is None or (isinstance(case, str) and case.strip() == "") + ) and _extra.get("case"): case = _extra.get("case") if (compact in (None, "")) and (_extra.get("compact") is not None): compact = _extra.get("compact") @@ -2597,15 +2919,15 @@ def _maybe_dict(val: Any) -> Dict[str, Any]: queries = [str(query).strip()] # Accept common alias keys and camelCase from clients - if (limit is None or (isinstance(limit, str) and limit.strip() == "")) and ( + if kwargs and (limit is None or (isinstance(limit, str) and limit.strip() == "")) and ( "top_k" in kwargs ): limit = kwargs.get("top_k") - if include_memories is None and ("includeMemories" in kwargs): + if kwargs and include_memories is None and ("includeMemories" in kwargs): include_memories = kwargs.get("includeMemories") - if memory_weight is None and ("memoryWeight" in kwargs): + if kwargs and memory_weight is None and ("memoryWeight" in kwargs): memory_weight = kwargs.get("memoryWeight") - if per_source_limits is None and ("perSourceLimits" in kwargs): + if kwargs and per_source_limits is None and ("perSourceLimits" in kwargs): per_source_limits = kwargs.get("perSourceLimits") # Smart defaults inspired by stored preferences, but without external calls @@ -2720,10 +3042,14 @@ def _maybe_dict(val: Any) -> Dict[str, Any]: "DBG_CTX_SRCH_CODE_RES", extra={ "type": type(code_res).__name__, - "has_results": bool(isinstance(code_res, dict) and isinstance(code_res.get("results"), list)), + "has_results": bool( + isinstance(code_res, dict) + and isinstance(code_res.get("results"), list) + ), "len_results": ( len(code_res.get("results")) - if isinstance(code_res, dict) and isinstance(code_res.get("results"), list) + if isinstance(code_res, dict) + and isinstance(code_res.get("results"), list) else None ), "code_hits": len(code_hits), @@ -2737,13 +3063,20 @@ def _maybe_dict(val: Any) -> Dict[str, Any]: if not code_hits: try: from scripts.mcp_router import call_tool_http # type: ignore - base = (os.environ.get("MCP_INDEXER_HTTP_URL") or "http://localhost:8003/mcp").rstrip("/") + + base = ( + os.environ.get("MCP_INDEXER_HTTP_URL") or "http://localhost:8003/mcp" + ).rstrip("/") http_args = { - "query": (queries if len(queries) > 1 else (queries[0] if queries else "")), + "query": ( + queries if len(queries) > 1 else (queries[0] if queries else "") + ), "limit": int(code_limit), "per_path": int(per_path_val), "include_snippet": bool(include_snippet), - "context_lines": int(context_lines) if context_lines not in (None, "") else 2, + "context_lines": int(context_lines) + if context_lines not in (None, "") + else 2, "collection": coll, "language": language or "", "under": under or "", @@ -2758,8 +3091,12 @@ def _maybe_dict(val: Any) -> Dict[str, Any]: "compact": bool(eff_compact), } timeout = float(os.environ.get("CONTEXT_SEARCH_HTTP_TIMEOUT", "20") or 20) - resp = await asyncio.to_thread(lambda: call_tool_http(base, "repo_search", http_args, timeout=timeout)) - r = ((resp.get("result") or {}).get("structuredContent") or {}).get("result") or {} + resp = await asyncio.to_thread( + lambda: call_tool_http(base, "repo_search", http_args, timeout=timeout) + ) + r = ((resp.get("result") or {}).get("structuredContent") or {}).get( + "result" + ) or {} http_items = r.get("results") or [] if isinstance(http_items, list): for obj in http_items: @@ -2778,7 +3115,9 @@ def _maybe_dict(val: Any) -> Dict[str, Any]: used_http_fallback = True if os.environ.get("DEBUG_CONTEXT_SEARCH"): try: - logger.debug("DBG_CTX_SRCH_HTTP_FALLBACK", extra={"count": len(code_hits)}) + logger.debug( + "DBG_CTX_SRCH_HTTP_FALLBACK", extra={"count": len(code_hits)} + ) except Exception: pass except Exception: @@ -2789,6 +3128,7 @@ def _maybe_dict(val: Any) -> Dict[str, Any]: if not code_hits and queries: try: from scripts.hybrid_search import run_hybrid_search # type: ignore + model_name = os.environ.get("EMBEDDING_MODEL", "BAAI/bge-base-en-v1.5") model = _get_embedding_model(model_name) prev_coll = os.environ.get("COLLECTION_NAME") @@ -2808,7 +3148,8 @@ def _maybe_dict(val: Any) -> Dict[str, Any]: path_regex=path_regex or None, path_glob=path_glob or None, not_glob=not_glob or None, - expand=str(os.environ.get("HYBRID_EXPAND", "0")).strip().lower() in {"1", "true", "yes", "on"}, + expand=str(os.environ.get("HYBRID_EXPAND", "0")).strip().lower() + in {"1", "true", "yes", "on"}, model=model, ) finally: @@ -3260,9 +3601,19 @@ def push_text( "include_memories": bool(include_mem), "memory_weight": float(mw), "include_snippet": bool(include_snippet), - "context_lines": int(context_lines) if context_lines not in (None, "") else 2, + "context_lines": int(context_lines) + if context_lines not in (None, "") + else 2, "compact": bool(eff_compact), } + try: + if isinstance(code_res, dict): + ret["diag"]["rerank"] = { + "used_rerank": bool(code_res.get("used_rerank")), + "counters": code_res.get("rerank_counters") or {}, + } + except Exception: + pass return ret ret = {"results": blended, "total": len(blended)} @@ -3286,6 +3637,8 @@ def push_text( "compact": bool(eff_compact), } return ret + + @mcp.tool() async def expand_query(query: Any = None, max_new: Any = None) -> Dict[str, Any]: """LLM-assisted query expansion (local llama.cpp, if enabled). @@ -3312,18 +3665,13 @@ async def expand_query(query: Any = None, max_new: Any = None) -> Dict[str, Any] cap = max(0, min(2, int(max_new))) except (ValueError, TypeError): cap = 2 - runtime = str(os.environ.get("REFRAG_RUNTIME", "llamacpp")).strip().lower() - dec_on = str(os.environ.get("REFRAG_DECODER", "0")).strip().lower() in {"1", "true", "yes", "on"} - if runtime == "glm": - if not (dec_on and os.environ.get("GLM_API_KEY")): - return {"alternates": [], "hint": "decoder disabled: set REFRAG_DECODER=1 and GLM_API_KEY; or use llamacpp"} - from scripts.refrag_glm import GLMRefragClient # type: ignore - client = GLMRefragClient() - else: - from scripts.refrag_llamacpp import LlamaCppRefragClient, is_decoder_enabled # type: ignore - if not is_decoder_enabled(): - return {"alternates": [], "hint": "decoder disabled: set REFRAG_DECODER=1 and start llamacpp (LLAMACPP_URL)"} - client = LlamaCppRefragClient() + from scripts.refrag_llamacpp import LlamaCppRefragClient, is_decoder_enabled # type: ignore + + if not is_decoder_enabled(): + return { + "alternates": [], + "hint": "decoder disabled: set REFRAG_DECODER=1 and start llamacpp (LLAMACPP_URL)", + } if not qlist: return {"alternates": []} prompt = ( @@ -3340,6 +3688,7 @@ async def expand_query(query: Any = None, max_new: Any = None) -> Dict[str, Any] stop=["\n\n"], ) import json as _json + alts: list[str] = [] try: parsed = _json.loads(out) @@ -3353,14 +3702,32 @@ async def expand_query(query: Any = None, max_new: Any = None) -> Dict[str, Any] pass return {"alternates": alts} except Exception as e: + fallback_alts: list[str] = [] + for q in qlist: + q = q.strip() + if not q: + continue + for suffix in (" implementation", " usage", " example", " test"): + cand = f"{q}{suffix}" + if cand not in qlist and cand not in fallback_alts: + fallback_alts.append(cand) + if len(fallback_alts) >= cap: + break + if len(fallback_alts) >= cap: + break + if fallback_alts: + return { + "alternates": fallback_alts[:cap], + "hint": f"decoder fallback: {e}", + } return {"alternates": [], "error": str(e)} - # Lightweight cleanup to reduce repetition from small models def _cleanup_answer(text: str, max_chars: int | None = None) -> str: try: import re + t = (text or "").strip() if not t: return t @@ -3408,7 +3775,7 @@ def _cleanup_answer(text: str, max_chars: int | None = None) -> str: t2 = " ".join(out) # Optional final cap if max_chars and max_chars > 0 and len(t2) > max_chars: - t2 = t2[: max(0, max_chars - 3) ] + "..." + t2 = t2[: max(0, max_chars - 3)] + "..." return t2 except Exception: return text @@ -3420,7 +3787,9 @@ def _answer_style_guidance() -> str: return ( "Write a direct answer in 2-4 sentences. No headings or labels. " "Ground non-trivial claims with bracketed citations like [n] using the numbered Sources. " - "Quote exact code lines when relevant. If the snippets are insufficient, respond exactly: insufficient context." + "Never invent functions or parameters that do not appear in the snippets. " + "Do not include URLs or Markdown links of any kind; cite only with [n]. " + "If the Sources list is empty or the snippets are insufficient, respond exactly: insufficient context." ) @@ -3445,27 +3814,45 @@ def _validate_answer_output(text: str, citations: list) -> dict: try: t = (text or "").strip() low = t.lower() - has_cite = ("[" in t and "]" in t) or not citations + requires_cite = bool(citations) + has_refs = "[" in t and "]" in t + is_insufficient = low == "insufficient context" hedge_terms = ["likely", "might", "could", "appears", "seems", "probably"] hedge_score = sum(low.count(w) for w in hedge_terms) # Configurable cutoff: allow citation/quote/paren endings and tune min length via CTX_CUTOFF_MIN_CHARS (default 220) - MIN = safe_int(os.environ.get("CTX_CUTOFF_MIN_CHARS", ""), default=220, logger=logger, context="CTX_CUTOFF_MIN_CHARS") + MIN = safe_int( + os.environ.get("CTX_CUTOFF_MIN_CHARS", ""), + default=220, + logger=logger, + context="CTX_CUTOFF_MIN_CHARS", + ) valid_end = (".", "!", "?", "]", '"', "'", "”", "’", ")") tail = t.rstrip() - looks_cutoff = (len(tail) > MIN and not tail.endswith(valid_end)) - ok = bool(t) and (has_cite or not citations) and hedge_score < 4 and not looks_cutoff + looks_cutoff = len(tail) > MIN and not tail.endswith(valid_end) + ok = ( + bool(t) + and (is_insufficient or (requires_cite and has_refs)) + and hedge_score < 4 + and not looks_cutoff + ) return { "ok": ok, - "has_citation_refs": has_cite, + "has_citation_refs": (has_refs or is_insufficient), "hedge_score": hedge_score, "looks_cutoff": looks_cutoff, } except Exception: - return {"ok": True, "has_citation_refs": True, "hedge_score": 0, "looks_cutoff": False} + return { + "ok": True, + "has_citation_refs": True, + "hedge_score": 0, + "looks_cutoff": False, + } # ----- context_answer refactor helpers ----- + def _ca_unwrap_and_normalize( query: Any, limit: Any, @@ -3501,7 +3888,11 @@ def _ca_unwrap_and_normalize( for kk, vv in v.items(): _raw.setdefault(kk, vv) except (TypeError, AttributeError) as e: - logger.warning("Failed to unwrap nested kwargs", exc_info=e, extra={"raw_keys": list(_raw.keys())}) + logger.warning( + "Failed to unwrap nested kwargs", + exc_info=e, + extra={"raw_keys": list(_raw.keys())}, + ) # Prefer non-empty override from wrapper def _coalesce(val, fallback): @@ -3533,7 +3924,11 @@ def _coalesce(val, fallback): path_glob = _coalesce(_raw.get("path_glob"), path_glob) not_glob = _coalesce(_raw.get("not_glob"), not_glob) case = _coalesce(_raw.get("case"), case) - not_ = _coalesce(_raw.get("not_"), not_) if _raw.get("not_") is not None else _coalesce(_raw.get("not"), not_) + not_ = ( + _coalesce(_raw.get("not_"), not_) + if _raw.get("not_") is not None + else _coalesce(_raw.get("not"), not_) + ) # Normalize query to list[str] queries: list[str] = [] @@ -3547,7 +3942,9 @@ def _coalesce(val, fallback): if s: queries = [s] except (TypeError, ValueError) as e: - logger.warning("Failed to normalize query", exc_info=e, extra={"raw_query": query}) + logger.warning( + "Failed to normalize query", exc_info=e, extra={"raw_query": query} + ) raise ValidationError(f"Invalid query format: {e}") if not queries: @@ -3560,6 +3957,7 @@ def _coalesce(val, fallback): # Adjust per_path for identifier-focused questions try: import re as _re + _ids0 = _re.findall(r"\b([A-Z_][A-Z0-9_]{2,})\b", " ".join(queries)) if _ids0: ppath = max(ppath, 5) @@ -3596,7 +3994,6 @@ def _coalesce(val, fallback): } - def _ca_prepare_filters_and_retrieve( queries: list[str], lim: int, @@ -3648,6 +4045,7 @@ def _variants(p: str) -> list[str]: default_not_glob.extend(_variants(b)) qtext = " ".join(queries).lower() + def _mentions_any(keys: list[str]) -> bool: return any(k in qtext for k in keys) @@ -3655,11 +4053,33 @@ def _mentions_any(keys: list[str]) -> bool: if not _mentions_any([".env", "dotenv", "environment variable", "env var"]): maybe_excludes += [".env", ".env.*"] if not _mentions_any(["docker-compose", "compose"]): - maybe_excludes += ["docker-compose*.yml", "docker-compose*.yaml", "compose*.yml", "compose*.yaml"] - if not _mentions_any(["lock", "package-lock.json", "pnpm-lock", "yarn.lock", "poetry.lock", "cargo.lock", "go.sum", "composer.lock"]): maybe_excludes += [ - "*.lock", "package-lock.json", "pnpm-lock.yaml", "yarn.lock", - "poetry.lock", "Cargo.lock", "go.sum", "composer.lock", + "docker-compose*.yml", + "docker-compose*.yaml", + "compose*.yml", + "compose*.yaml", + ] + if not _mentions_any( + [ + "lock", + "package-lock.json", + "pnpm-lock", + "yarn.lock", + "poetry.lock", + "cargo.lock", + "go.sum", + "composer.lock", + ] + ): + maybe_excludes += [ + "*.lock", + "package-lock.json", + "pnpm-lock.yaml", + "yarn.lock", + "poetry.lock", + "Cargo.lock", + "go.sum", + "composer.lock", ] if not _mentions_any(["appsettings", "settings.json", "config"]): maybe_excludes += ["appsettings*.json"] @@ -3688,6 +4108,21 @@ def _to_glob_list(val: Any) -> list[str]: eff_path_glob: list[str] = list(user_path_glob) auto_path_glob: list[str] = [] + # Heuristic: detect explicit file mentions in the queries and bias retrieval + try: + import re as _re + mentioned = _re.findall(r"([A-Za-z0-9_./-]+\.[A-Za-z0-9_]+)", qtext) + for m in mentioned: + mm = str(m).replace('\\\\','/').lstrip('/') + if not mm: + continue + fn = mm.split('/')[-1] + # Prefer filename and full relative path variants + auto_path_glob.append(f"**/{fn}") + auto_path_glob.append(f"**/{mm}") + except Exception: + pass + def _abs_prefix(val: str) -> str: v = (val or "").replace("\\", "/") if not v: @@ -3737,12 +4172,18 @@ def _abs_prefix(val: str) -> str: try: qj = " ".join(queries) import re as _re + primary = _primary_identifier_from_queries(queries) - if primary and any(word in qj.lower() for word in ["what is", "how is", "used", "usage", "define"]): + if primary and any( + word in qj.lower() + for word in ["what is", "how is", "used", "usage", "define"] + ): + def _add_query(q: str): qs = q.strip() if qs and qs not in queries: queries.append(qs) + _add_query(primary) _add_query(f"{primary} =") func_name = primary.lower().split("_")[0] @@ -3752,7 +4193,14 @@ def _add_query(q: str): logger.debug("Failed to augment query with identifier probes", exc_info=e) if os.environ.get("DEBUG_CONTEXT_ANSWER"): - logger.debug("FILTERS", extra={"language": req_language, "override_under": override_under, "path_glob": eff_path_glob}) + logger.debug( + "FILTERS", + extra={ + "language": req_language, + "override_under": override_under, + "path_glob": eff_path_glob, + }, + ) # Sanitize symbol sym_arg = kwargs.get("symbol") or filters.get("symbol") or None @@ -3764,6 +4212,7 @@ def _add_query(q: str): # Run retrieval from scripts.hybrid_search import run_hybrid_search # type: ignore + items = run_hybrid_search( queries=queries, limit=int(max(lim, 4)), @@ -3773,23 +4222,36 @@ def _add_query(q: str): kind=(kind or kwargs.get("kind") or None), symbol=sym_arg, ext=(ext or kwargs.get("ext") or None), - not_filter=(filters.get("not_") or kwargs.get("not_") or kwargs.get("not") or None), + not_filter=( + filters.get("not_") or kwargs.get("not_") or kwargs.get("not") or None + ), case=(case or kwargs.get("case") or None), path_regex=(path_regex or kwargs.get("path_regex") or None), path_glob=(eff_path_glob or None), not_glob=eff_not_glob, - expand=False if did_local_expand else (str(os.environ.get("HYBRID_EXPAND", "0")).strip().lower() in {"1","true","yes","on"}), + expand=False + if did_local_expand + else ( + str(os.environ.get("HYBRID_EXPAND", "0")).strip().lower() + in {"1", "true", "yes", "on"} + ), model=model, ) if os.environ.get("DEBUG_CONTEXT_ANSWER"): try: - print("[DEBUG] TIER1 items:", len(items), "first path:", (items[0].get("path") if items else None)) + print( + "[DEBUG] TIER1 items:", + len(items), + "first path:", + (items[0].get("path") if items else None), + ) except Exception: pass # Usage augmentation for identifier try: import re as _re + qj2 = " ".join(queries) _ids = _re.findall(r"\b([A-Z_][A-Z0-9_]{2,})\b", qj2) _asked = _ids[0] if _ids else "" @@ -3798,10 +4260,16 @@ def _add_query(q: str): _usage_qs: list[str] = [] if _fname and len(_fname) >= 2: _usage_qs.append(f"def {_fname}(") - _usage_qs.extend([ - f"{_asked})", f"{_asked},", f"= {_asked}", f"{_asked} =", - f"{_asked} = int(os.environ.get", f"int(os.environ.get(\"{_asked}\"", - ]) + _usage_qs.extend( + [ + f"{_asked})", + f"{_asked},", + f"= {_asked}", + f"{_asked} =", + f"{_asked} = int(os.environ.get", + f'int(os.environ.get("{_asked}"', + ] + ) _usage_qs = [u for u in _usage_qs if u and u not in queries] if _usage_qs: usage_items = run_hybrid_search( @@ -3813,17 +4281,33 @@ def _add_query(q: str): kind=(kind or kwargs.get("kind") or None), symbol=sym_arg, ext=(ext or kwargs.get("ext") or None), - not_filter=(filters.get("not_") or kwargs.get("not_") or kwargs.get("not") or None), + not_filter=( + filters.get("not_") + or kwargs.get("not_") + or kwargs.get("not") + or None + ), case=(case or kwargs.get("case") or None), path_regex=(path_regex or kwargs.get("path_regex") or None), path_glob=(eff_path_glob or None), not_glob=eff_not_glob, - expand=False if did_local_expand else (str(os.environ.get("HYBRID_EXPAND", "0")).strip().lower() in {"1","true","yes","on"}), + expand=False + if did_local_expand + else ( + str(os.environ.get("HYBRID_EXPAND", "0")).strip().lower() + in {"1", "true", "yes", "on"} + ), model=model, ) + def _ikey(it: Dict[str, Any]): - return (str(it.get("path") or ""), int(it.get("start_line") or 0), int(it.get("end_line") or 0)) - _seen = { _ikey(it) for it in items } + return ( + str(it.get("path") or ""), + int(it.get("start_line") or 0), + int(it.get("end_line") or 0), + ) + + _seen = {_ikey(it) for it in items} for it in usage_items: k = _ikey(it) if k not in _seen: @@ -3840,12 +4324,22 @@ def _ikey(it: Dict[str, Any]): kind=(kind or kwargs.get("kind") or None), symbol=sym_arg, ext=(ext or kwargs.get("ext") or None), - not_filter=(filters.get("not_") or kwargs.get("not_") or kwargs.get("not") or None), + not_filter=( + filters.get("not_") + or kwargs.get("not_") + or kwargs.get("not") + or None + ), case=(case or kwargs.get("case") or None), path_regex=(path_regex or kwargs.get("path_regex") or None), path_glob=(eff_path_glob or None), not_glob=eff_not_glob, - expand=False if did_local_expand else (str(os.environ.get("HYBRID_EXPAND", "0")).strip().lower() in {"1","true","yes","on"}), + expand=False + if did_local_expand + else ( + str(os.environ.get("HYBRID_EXPAND", "0")).strip().lower() + in {"1", "true", "yes", "on"} + ), model=model, ) @@ -3858,6 +4352,7 @@ def _ikey(it: Dict[str, Any]): from scripts.hybrid_search import lang_matches_path as _lmp # type: ignore except Exception: _lmp = None + def _ok_lang(it: Dict[str, Any]) -> bool: p = str(it.get("path") or "") if callable(_lmp): @@ -3886,8 +4381,55 @@ def _ok_lang(it: Dict[str, Any]) -> bool: } lang_exts = table.get(str(req_language).lower(), []) return any(ext in lang_exts for ext in extensions) + items = [it for it in items if _ok_lang(it)] + # Targeted fallback: if query mentions a specific path and it's missing from results, add a small span from that file + try: + import re as _re + mentioned = _re.findall(r"([A-Za-z0-9_./-]+\.[A-Za-z0-9_]+)", qtext) + if mentioned: + # Normalize to repo-relative paths + def _normp(p: str) -> str: + p = str(p).replace('\\\\','/').lstrip('/') + return p + mentioned = [_normp(m) for m in mentioned if m] + have_paths = {str(it.get('path') or '').lstrip('/') for it in items} + for m in mentioned: + if m in have_paths: + continue + abs_path = m if os.path.isabs(m) else os.path.join(cwd_root, m) + if not os.path.exists(abs_path): + continue + try: + with open(abs_path, 'r', encoding='utf-8', errors='ignore') as f: + lines = f.readlines() + primary = _primary_identifier_from_queries(queries) + start = 1 + end = min(len(lines), start + 20) + if primary and len(primary) >= 3: + for idx, line in enumerate(lines, 1): + if _re.search(rf"\b{_re.escape(primary)}\b\s*[=:(]", line): + start = max(1, idx - 2) + end = min(len(lines), idx + 8) + break + snippet_text = "".join(lines[start-1:end]).strip() + if snippet_text: + items.append({ + 'path': m, + 'start_line': start, + 'end_line': end, + 'text': snippet_text, + 'score': 1.0, + 'tier': 'path_mention', + 'language': req_language or None, + 'kind': 'definition', + }) + except Exception: + pass + except Exception: + pass + return { "items": items, "eff_language": req_language, @@ -3903,9 +4445,6 @@ def _ok_lang(it: Dict[str, Any]) -> bool: } - - - def _ca_fallback_and_budget( *, items: list[Dict[str, Any]], @@ -3953,8 +4492,8 @@ def _ok_lang(it: Dict[str, Any]) -> bool: if len(parts) > 1: extensions.add(parts[-1].lower()) if len(parts) > 2: - # DEBUG: marker to observe fallback invocation in tests - # print will be captured by pytest -s only + # DEBUG: marker to observe fallback invocation in tests + # print will be captured by pytest -s only multi_ext = ".".join(parts[-2:]).lower() extensions.add(multi_ext) @@ -3983,6 +4522,7 @@ def _ok_lang(it: Dict[str, Any]) -> bool: extra={"stage": "tier2"}, ) from scripts.hybrid_search import run_hybrid_search # type: ignore + with _env_overrides({"REFRAG_GATE_FIRST": "0"}): items = run_hybrid_search( queries=queries, @@ -4021,66 +4561,105 @@ def _ok_lang(it: Dict[str, Any]) -> bool: path_regex=None, path_glob=None, not_glob=eff_not_glob, - expand=False if did_local_expand else (str(os.environ.get("HYBRID_EXPAND", "0")).strip().lower() in {"1", "true", "yes", "on"}), + expand=False + if did_local_expand + else ( + str(os.environ.get("HYBRID_EXPAND", "0")).strip().lower() + in {"1", "true", "yes", "on"} + ), model=model, ) if os.environ.get("DEBUG_CONTEXT_ANSWER"): - logger.debug("TIER2: broader hybrid returned items", extra={"count": len(items)}) + logger.debug( + "TIER2: broader hybrid returned items", extra={"count": len(items)} + ) try: - print("[DEBUG] TIER2 items:", len(items), "first path:", (items[0].get("path") if items else None)) + print( + "[DEBUG] TIER2 items:", + len(items), + "first path:", + (items[0].get("path") if items else None), + ) except Exception: pass # Multi-collection fallback: index-only search across other workspaces/collections try: - _mc_enabled = str(os.environ.get("CTX_MULTI_COLLECTION", "1")).strip().lower() in {"1","true","yes","on"} + _mc_enabled = str( + os.environ.get("CTX_MULTI_COLLECTION", "1") + ).strip().lower() in {"1", "true", "yes", "on"} if _mc_enabled and (len(items) < max(2, int(lim) // 2)): # Discover other workspace collections (search parent of cwd by default) from scripts.workspace_state import list_workspaces as _ws_list_workspaces # type: ignore + try: _sr = os.environ.get("WORKSPACE_SEARCH_ROOT") if not _sr: from pathlib import Path as _Path + _sr = str(_Path(os.getcwd()).resolve().parent) except Exception: _sr = "/work" _workspaces = _ws_list_workspaces(_sr) or [] _current_coll = os.environ.get("COLLECTION_NAME") or "" - _colls = [w.get("collection_name") for w in _workspaces if isinstance(w, dict) and w.get("collection_name")] - _colls = [c for c in _colls if isinstance(c, str) and c.strip() and c.strip() != _current_coll] - _maxc = safe_int(os.environ.get("CTX_MAX_COLLECTIONS", "4"), default=4, logger=logger, context="CTX_MAX_COLLECTIONS") + _colls = [ + w.get("collection_name") + for w in _workspaces + if isinstance(w, dict) and w.get("collection_name") + ] + _colls = [ + c + for c in _colls + if isinstance(c, str) and c.strip() and c.strip() != _current_coll + ] + _maxc = safe_int( + os.environ.get("CTX_MAX_COLLECTIONS", "4"), + default=4, + logger=logger, + context="CTX_MAX_COLLECTIONS", + ) _colls = _colls[: max(0, _maxc)] if _colls: from scripts.hybrid_search import run_hybrid_search as _rhs # type: ignore + _agg: list[Dict[str, Any]] = [] for _c in _colls: try: with _env_overrides({"COLLECTION_NAME": _c}): - _res = _rhs( - queries=queries, - limit=int(max(lim, 8)), - per_path=int(max(ppath, 2)), - language=eff_language, - under=override_under or None, - kind=kind or None, - symbol=sym_arg or None, - ext=ext or None, - not_filter=not_ or None, - case=case or None, - path_regex=path_regex or None, - path_glob=eff_path_glob, - not_glob=eff_not_glob, - expand=str(os.environ.get("HYBRID_EXPAND", "0")).strip().lower() in {"1","true","yes","on"}, - model=model, - ) or [] + _res = ( + _rhs( + queries=queries, + limit=int(max(lim, 8)), + per_path=int(max(ppath, 2)), + language=eff_language, + under=override_under or None, + kind=kind or None, + symbol=sym_arg or None, + ext=ext or None, + not_filter=not_ or None, + case=case or None, + path_regex=path_regex or None, + path_glob=eff_path_glob, + not_glob=eff_not_glob, + expand=str(os.environ.get("HYBRID_EXPAND", "0")) + .strip() + .lower() + in {"1", "true", "yes", "on"}, + model=model, + ) + or [] + ) for _it in _res: if isinstance(_it, dict): _agg.append(_it) except Exception: if os.environ.get("DEBUG_CONTEXT_ANSWER"): try: - logger.debug("MULTI_COLLECTION_ONE_FAILED", extra={"collection": _c}) + logger.debug( + "MULTI_COLLECTION_ONE_FAILED", + extra={"collection": _c}, + ) except Exception: pass if _agg: @@ -4095,11 +4674,25 @@ def _ok_lang(it: Dict[str, Any]) -> bool: if _k[0] and _k not in _seen: _seen.add(_k) _ded.append(_it) - _ded.sort(key=lambda x: float(x.get("score") or x.get("fusion_score") or x.get("raw_score") or 0.0), reverse=True) + _ded.sort( + key=lambda x: float( + x.get("score") + or x.get("fusion_score") + or x.get("raw_score") + or 0.0 + ), + reverse=True, + ) items = (items or []) + _ded[: int(max(lim, 4))] if os.environ.get("DEBUG_CONTEXT_ANSWER"): try: - logger.debug("MULTI_COLLECTION", extra={"count": len(_ded), "first": (_ded[0].get("path") if _ded else None)}) + logger.debug( + "MULTI_COLLECTION", + extra={ + "count": len(_ded), + "first": (_ded[0].get("path") if _ded else None), + }, + ) except Exception: pass except Exception: @@ -4107,22 +4700,43 @@ def _ok_lang(it: Dict[str, Any]) -> bool: logger.debug("MULTI_COLLECTION_FAIL", exc_info=True) # Doc-aware retrieval pass: pull READMEs/docs when results are thin (index-only) try: - _doc_enabled = str(os.environ.get("CTX_DOC_PASS", "1")).strip().lower() in {"1","true","yes","on"} + _doc_enabled = str(os.environ.get("CTX_DOC_PASS", "1")).strip().lower() in { + "1", + "true", + "yes", + "on", + } _qtext = " ".join([q for q in (queries or []) if isinstance(q, str)]).lower() - _broad_tokens = ("how", "explain", "overview", "architecture", "design", "work", "works", "guide", "readme") + _broad_tokens = ( + "how", + "explain", + "overview", + "architecture", + "design", + "work", + "works", + "guide", + "readme", + ) _looks_broad = any(t in _qtext for t in _broad_tokens) _pre_doc_len = len(items or []) - # Consider docs pass when results are thin OR the query looks broad if _doc_enabled and ((len(items) < max(3, int(lim) // 2)) or _looks_broad): # Skip if the user provided strict filters; this is for broad prompts _doc_strict_filters = bool( - eff_language or eff_path_glob or path_regex or sym_arg or ext or kind or override_under + eff_language + or eff_path_glob + or path_regex + or sym_arg + or ext + or kind + or override_under ) if not _doc_strict_filters: from scripts.hybrid_search import run_hybrid_search as _rhs # type: ignore + _doc_globs = [ "**/README*", "README*", @@ -4137,25 +4751,36 @@ def _ok_lang(it: Dict[str, Any]) -> bool: "**/*.txt", "**/*.adoc", ] - _doc_results = _rhs( - queries=queries, - limit=int(max(lim, 8)), - per_path=int(max(ppath, 2)), - language=None, - under=override_under or None, - kind=None, - symbol=None, - ext=None, - not_filter=not_ or None, - case=case or None, - path_regex=None, - path_glob=_doc_globs, - not_glob=eff_not_glob, - expand=str(os.environ.get("HYBRID_EXPAND", "0")).strip().lower() in {"1","true","yes","on"}, - model=model, - ) or [] + _doc_results = ( + _rhs( + queries=queries, + limit=int(max(lim, 8)), + per_path=int(max(ppath, 2)), + language=None, + under=override_under or None, + kind=None, + symbol=None, + ext=None, + not_filter=not_ or None, + case=case or None, + path_regex=None, + path_glob=_doc_globs, + not_glob=eff_not_glob, + expand=str(os.environ.get("HYBRID_EXPAND", "0")).strip().lower() + in {"1", "true", "yes", "on"}, + model=model, + ) + or [] + ) if _doc_results: - _seen = set((str(it.get("path") or ""), int(it.get("start_line") or 0), int(it.get("end_line") or 0)) for it in (items or [])) + _seen = set( + ( + str(it.get("path") or ""), + int(it.get("start_line") or 0), + int(it.get("end_line") or 0), + ) + for it in (items or []) + ) _merged = [] for it in _doc_results: if not isinstance(it, dict): @@ -4169,52 +4794,106 @@ def _ok_lang(it: Dict[str, Any]) -> bool: _seen.add(_k) _merged.append(it) # Prefer highest scoring doc snippets, but cap to avoid crowding out code spans - _merged.sort(key=lambda x: float(x.get("score") or x.get("fusion_score") or x.get("raw_score") or 0.0), reverse=True) + _merged.sort( + key=lambda x: float( + x.get("score") + or x.get("fusion_score") + or x.get("raw_score") + or 0.0 + ), + reverse=True, + ) _cap = max(2, int(lim) // 2) items = (items or []) + _merged[:_cap] if os.environ.get("DEBUG_CONTEXT_ANSWER"): try: - logger.debug("DOC_PASS", extra={"count": len(_merged), "first": (_merged[0].get("path") if _merged else None)}) + logger.debug( + "DOC_PASS", + extra={ + "count": len(_merged), + "first": ( + _merged[0].get("path") if _merged else None + ), + }, + ) except Exception: pass # If broad prompt and doc pass added nothing, try top-docs fallback try: - _doc_top_enabled = str(os.environ.get("CTX_DOC_TOP_FALLBACK", "1")).strip().lower() in {"1","true","yes","on"} - if _doc_top_enabled and _looks_broad and len(items or []) == _pre_doc_len: + _doc_top_enabled = str( + os.environ.get("CTX_DOC_TOP_FALLBACK", "1") + ).strip().lower() in {"1", "true", "yes", "on"} + if ( + _doc_top_enabled + and _looks_broad + and len(items or []) == _pre_doc_len + ): _fallback_qs = ["overview", "architecture", "readme"] - _top = _rhs( - queries=_fallback_qs, - limit=int(max(lim, 6)), - per_path=int(max(ppath, 2)), - language=None, - under=override_under or None, - kind=None, - symbol=None, - ext=None, - not_filter=not_ or None, - case=case or None, - path_regex=None, - path_glob=_doc_globs, - not_glob=eff_not_glob, - expand=False, - model=model, - ) or [] + _top = ( + _rhs( + queries=_fallback_qs, + limit=int(max(lim, 6)), + per_path=int(max(ppath, 2)), + language=None, + under=override_under or None, + kind=None, + symbol=None, + ext=None, + not_filter=not_ or None, + case=case or None, + path_regex=None, + path_glob=_doc_globs, + not_glob=eff_not_glob, + expand=False, + model=model, + ) + or [] + ) if _top: - _seen2 = set((str(it.get("path") or ""), int(it.get("start_line") or 0), int(it.get("end_line") or 0)) for it in (items or [])) + _seen2 = set( + ( + str(it.get("path") or ""), + int(it.get("start_line") or 0), + int(it.get("end_line") or 0), + ) + for it in (items or []) + ) _merged2 = [] for it in _top: if not isinstance(it, dict): continue - _k = (str(it.get("path") or ""), int(it.get("start_line") or 0), int(it.get("end_line") or 0)) + _k = ( + str(it.get("path") or ""), + int(it.get("start_line") or 0), + int(it.get("end_line") or 0), + ) if _k[0] and _k not in _seen2: _seen2.add(_k) _merged2.append(it) - _merged2.sort(key=lambda x: float(x.get("score") or x.get("fusion_score") or x.get("raw_score") or 0.0), reverse=True) + _merged2.sort( + key=lambda x: float( + x.get("score") + or x.get("fusion_score") + or x.get("raw_score") + or 0.0 + ), + reverse=True, + ) _cap2 = max(1, min(2, int(lim) // 3)) items = (items or []) + _merged2[:_cap2] if os.environ.get("DEBUG_CONTEXT_ANSWER"): try: - logger.debug("DOC_TOP_FALLBACK", extra={"count": len(_merged2), "first": (_merged2[0].get("path") if _merged2 else None)}) + logger.debug( + "DOC_TOP_FALLBACK", + extra={ + "count": len(_merged2), + "first": ( + _merged2[0].get("path") + if _merged2 + else None + ), + }, + ) except Exception: pass except Exception: @@ -4236,13 +4915,22 @@ def _ok_lang(it: Dict[str, Any]) -> bool: or override_under ) # If Tier-1 and Tier-2 yielded nothing, do a tiny filesystem scan as a last resort - if (not items) and not did_local_expand and not _strict_filters and str(os.environ.get("CTX_TIER3_FS", "0")).strip().lower() in {"1","true","yes","on"}: + if ( + (not items) + and not did_local_expand + and not _strict_filters + and str(os.environ.get("CTX_TIER3_FS", "0")).strip().lower() + in {"1", "true", "yes", "on"} + ): try: import re as _re + primary = _primary_identifier_from_queries(queries) if primary and len(primary) >= 3: if os.environ.get("DEBUG_CONTEXT_ANSWER"): - logger.debug("TIER3: filesystem scan", extra={"identifier": primary}) + logger.debug( + "TIER3: filesystem scan", extra={"identifier": primary} + ) scan_root = override_under or cwd_root if not os.path.isabs(scan_root): scan_root = os.path.join(cwd_root, scan_root) @@ -4253,12 +4941,15 @@ def _ok_lang(it: Dict[str, Any]) -> bool: dirs[:] = [ d for d in dirs - if not any(ex in d for ex in [ - ".git", - "node_modules", - ".pytest_cache", - "__pycache__", - ]) + if not any( + ex in d + for ex in [ + ".git", + "node_modules", + ".pytest_cache", + "__pycache__", + ] + ) ] for fname in files: if scanned >= max_files: @@ -4280,18 +4971,26 @@ def _ok_lang(it: Dict[str, Any]) -> bool: continue fpath = os.path.join(root, fname) try: - with open(fpath, "r", encoding="utf-8", errors="ignore") as f: + with open( + fpath, "r", encoding="utf-8", errors="ignore" + ) as f: lines = f.readlines() scanned += 1 for idx, line in enumerate(lines, 1): - if _re.search(rf"\b{_re.escape(primary)}\b\s*[=:(]", line): + if _re.search( + rf"\b{_re.escape(primary)}\b\s*[=:(]", line + ): try: rel_path = os.path.relpath(fpath, cwd_root) except ValueError: - rel_path = fpath.replace(cwd_root, "").lstrip("/\\") + rel_path = fpath.replace(cwd_root, "").lstrip( + "/\\" + ) snippet_start = max(1, idx - 2) snippet_end = min(len(lines), idx + 3) - snippet_text = "".join(lines[snippet_start - 1 : snippet_end]) + snippet_text = "".join( + lines[snippet_start - 1 : snippet_end] + ) ext_map = { ".py": "python", ".js": "javascript", @@ -4304,7 +5003,11 @@ def _ok_lang(it: Dict[str, Any]) -> bool: ".h": "c", } lang = next( - (v for k, v in ext_map.items() if fname.endswith(k)), + ( + v + for k, v in ext_map.items() + if fname.endswith(k) + ), "unknown", ) tier3_hits.append( @@ -4349,15 +5052,23 @@ def _ok_lang(it: Dict[str, Any]) -> bool: # Apply ReFRAG span budgeting to compress context from scripts.hybrid_search import _merge_and_budget_spans # type: ignore + try: if os.environ.get("DEBUG_CONTEXT_ANSWER"): logger.debug("BUDGET_BEFORE", extra={"items": len(items)}) _pairs = {} try: # Relax budgets for context_answer unless explicitly disabled via CTX_RELAX_BUDGETS=0 - if str(os.environ.get("CTX_RELAX_BUDGETS", "1")).strip().lower() in {"1", "true", "yes", "on"}: + if str(os.environ.get("CTX_RELAX_BUDGETS", "1")).strip().lower() in { + "1", + "true", + "yes", + "on", + }: _pairs = { - "MICRO_BUDGET_TOKENS": os.environ.get("MICRO_BUDGET_TOKENS", "1024"), + "MICRO_BUDGET_TOKENS": os.environ.get( + "MICRO_BUDGET_TOKENS", "1024" + ), "MICRO_OUT_MAX_SPANS": os.environ.get("MICRO_OUT_MAX_SPANS", "8"), } except Exception: @@ -4420,7 +5131,11 @@ def _span_haystack(span: Dict[str, Any]) -> str: parts = [ str(span.get("text") or ""), str(span.get("symbol") or ""), - str((span.get("relations") or {}).get("symbol_path") if isinstance(span.get("relations"), dict) else ""), + str( + (span.get("relations") or {}).get("symbol_path") + if isinstance(span.get("relations"), dict) + else "" + ), str(span.get("path") or ""), str(span.get("_ident_snippet") or ""), ] @@ -4494,7 +5209,10 @@ def _span_key(span: Dict[str, Any]) -> tuple[str, int, int]: if os.environ.get("DEBUG_CONTEXT_ANSWER"): logger.debug( "IDENT_AUGMENT", - extra={"candidates": len(ident_candidates), "ident": primary_ident}, + extra={ + "candidates": len(ident_candidates), + "ident": primary_ident, + }, ) source_spans = ident_candidates else: @@ -4572,7 +5290,15 @@ def _ca_build_citations_and_context( spans: list[Dict[str, Any]], include_snippet: bool, queries: list[str], -) -> tuple[list[Dict[str, Any]], list[str], dict[int, str], str | None, str, int | None, int | None]: +) -> tuple[ + list[Dict[str, Any]], + list[str], + dict[int, str], + str | None, + str, + int | None, + int | None, +]: """Build citations, read snippets, assemble context blocks, and extract def/usage hints. Returns (citations, context_blocks, snippets_by_id, asked_ident, def_line_exact, def_id, usage_id). """ @@ -4591,16 +5317,25 @@ def _ca_build_citations_and_context( eline = int(it.get("end_line") or 0) _hostp = it.get("host_path") _contp = it.get("container_path") + # Provide both container-absolute and repo-relative forms for compatibility + def _norm(p: str) -> str: + try: + if p.startswith("/work/"): + return p[len("/work/"):] + return p.lstrip("/") if p.startswith("/work") else p + except Exception: + return p _cit = { "id": idx, - "path": path, + "path": path, # keep original for backward compatibility (tests expect /work/...) + "rel_path": _norm(path), "start_line": sline, "end_line": eline, } if _hostp: - _cit["host_path"] = _hostp + _cit["host_path"] = _norm(str(_hostp)) if _contp: - _cit["container_path"] = _contp + _cit["container_path"] = str(_contp) citations.append(_cit) snippet = str(it.get("text") or "").strip() @@ -4610,6 +5345,7 @@ def _ca_build_citations_and_context( try: fp = path import os as _os + if not _os.path.isabs(fp): fp = _os.path.join("/work", fp) realp = _os.path.realpath(fp) @@ -4662,8 +5398,11 @@ def _ca_build_citations_and_context( try: if asked_ident and snippet: import re as _re + for _ln in str(snippet).splitlines(): - if not _def_line_exact and _re.match(rf"\s*{_re.escape(asked_ident)}\s*=", _ln): + if not _def_line_exact and _re.match( + rf"\s*{_re.escape(asked_ident)}\s*=", _ln + ): _def_line_exact = _ln.strip() _def_id = idx elif (asked_ident in _ln) and (_def_id != idx): @@ -4682,16 +5421,27 @@ def _ca_build_citations_and_context( }, ) - return citations, context_blocks, snippets_by_id, asked_ident, _def_line_exact, _def_id, _usage_id + return ( + citations, + context_blocks, + snippets_by_id, + asked_ident, + _def_line_exact, + _def_id, + _usage_id, + ) -def _ca_ident_supplement(paths: list[str], ident: str, *, include_snippet: bool, max_hits: int = 4) -> list[Dict[str, Any]]: +def _ca_ident_supplement( + paths: list[str], ident: str, *, include_snippet: bool, max_hits: int = 4 +) -> list[Dict[str, Any]]: """Lightweight FS supplement: when an identifier is asked but the retrieved spans missed its definition/usage, scan a small set of candidate files for that identifier and return minimal spans around the hits. Keeps scope tiny and safe. """ import os as _os import re as _re + out: list[Dict[str, Any]] = [] seen: set[tuple[str, int, int]] = set() ident = str(ident or "").strip() @@ -4704,7 +5454,7 @@ def _ca_ident_supplement(paths: list[str], ident: str, *, include_snippet: bool, pat_def = _re.compile(rf"\b{_re.escape(ident)}\b\s*=") pat_any = _re.compile(rf"\b{_re.escape(ident)}\b") - for p in (paths or []): + for p in paths or []: if len(out) >= max_hits: break try: @@ -4736,12 +5486,14 @@ def _ca_ident_supplement(paths: list[str], ident: str, *, include_snippet: bool, si = max(1, idx - margin) ei = min(len(lines), idx + margin) snippet = "".join(lines[si - 1 : ei]) - out.append({ - "path": p, - "start_line": idx, - "end_line": idx, - "_ident_snippet": snippet, - }) + out.append( + { + "path": p, + "start_line": idx, + "end_line": idx, + "_ident_snippet": snippet, + } + ) seen.add(key) if len(out) >= max_hits: break @@ -4757,25 +5509,40 @@ def _to_int(v, d): return int(v) except (ValueError, TypeError): return d + def _to_float(v, d): try: return float(v) except (ValueError, TypeError): return d + stop_env = os.environ.get("DECODER_STOP", "") - default_stops = ["<|end_of_text|>", "<|start_of_role|>", "<|end_of_response|>", "\n\n\n"] + default_stops = [ + "<|end_of_text|>", + "<|start_of_role|>", + "<|end_of_response|>", + "\n\n\n", + ] stops = default_stops + [s for s in (stop_env.split(",") if stop_env else []) if s] - mtok = _to_int(max_tokens, _to_int(os.environ.get("DECODER_MAX_TOKENS", "240"), 240)) + mtok = _to_int( + max_tokens, _to_int(os.environ.get("DECODER_MAX_TOKENS", "240"), 240) + ) temp = 0.0 top_k = _to_int(os.environ.get("DECODER_TOP_K", "20"), 20) top_p = _to_float(os.environ.get("DECODER_TOP_P", "0.85"), 0.85) return mtok, temp, top_k, top_p, stops -def _ca_build_prompt(context_blocks: list[str], citations: list[Dict[str, Any]], queries: list[str]) -> str: +def _ca_build_prompt( + context_blocks: list[str], citations: list[Dict[str, Any]], queries: list[str] +) -> str: qtxt = "\n".join(queries) docs_text = "\n\n".join(context_blocks) if context_blocks else "(no code found)" - sources_footer = "\n".join([f"[{c.get('id')}] {c.get('path')}" for c in citations]) if citations else "" + sources_footer = ( + "\n".join([f"[{c.get('id')}] {c.get('path')}" for c in citations]) + if citations + else "" + ) system_msg = ( "You are a helpful assistant with access to the following code snippets. " "You may use one or more snippets to assist with the user query.\n\n" @@ -4797,41 +5564,83 @@ def _ca_build_prompt(context_blocks: list[str], citations: list[Dict[str, Any]], return prompt -def _ca_decode(prompt: str, *, mtok: int, temp: float, top_k: int, top_p: float, stops: list[str]) -> str: - runtime = str(os.environ.get("REFRAG_RUNTIME", "llamacpp")).strip().lower() - if runtime == "glm": - from scripts.refrag_glm import GLMRefragClient # type: ignore - client = GLMRefragClient() - else: - from scripts.refrag_llamacpp import LlamaCppRefragClient # type: ignore - client = LlamaCppRefragClient() - return client.generate_with_soft_embeddings( - prompt=prompt, - max_tokens=mtok, - temperature=temp, - top_k=top_k, - top_p=top_p, - stop=stops, - repeat_penalty=float(os.environ.get("DECODER_REPEAT_PENALTY", "1.15") or 1.15), - repeat_last_n=int(os.environ.get("DECODER_REPEAT_LAST_N", "128") or 128), - ) +def _ca_decode( + prompt: str, *, mtok: int, temp: float, top_k: int, top_p: float, stops: list[str] +) -> str: + from scripts.refrag_llamacpp import LlamaCppRefragClient # type: ignore + + client = LlamaCppRefragClient() + base_tokens = int(max(16, mtok)) + last_err: Optional[Exception] = None + import time as _time + for attempt in range(3): + # Gradually reduce token budget on retries + cur_tokens = ( + base_tokens if attempt == 0 else max(16, base_tokens // (2 if attempt == 1 else 3)) + ) + try: + return client.generate_with_soft_embeddings( + prompt=prompt, + max_tokens=cur_tokens, + temperature=temp, + top_k=top_k, + top_p=top_p, + stop=stops, + repeat_penalty=float(os.environ.get("DECODER_REPEAT_PENALTY", "1.15") or 1.15), + repeat_last_n=int(os.environ.get("DECODER_REPEAT_LAST_N", "128") or 128), + ) + except Exception as e: + last_err = e + # Allow quick retries with reduced budget and tiny backoff to rescue transient 5xx + if attempt < 2: + _time.sleep(0.2 * (attempt + 1)) + continue + raise + if last_err: + raise last_err + raise RuntimeError("decoder call failed without explicit error") -def _ca_postprocess_answer(answer: str, citations: list[Dict[str, Any]], *, asked_ident: str | None = None, - def_line_exact: str | None = None, def_id: int | None = None, - usage_id: int | None = None, snippets_by_id: dict[int, str] | None = None) -> str: +def _ca_postprocess_answer( + answer: str, + citations: list[Dict[str, Any]], + *, + asked_ident: str | None = None, + def_line_exact: str | None = None, + def_id: int | None = None, + usage_id: int | None = None, + snippets_by_id: dict[int, str] | None = None, +) -> str: import re as _re + snippets_by_id = snippets_by_id or {} txt = (answer or "").strip() # Strip leaked stop tokens for stop_tok in ["<|end_of_text|>", "<|start_of_role|>", "<|end_of_response|>"]: txt = txt.replace(stop_tok, "") + # Remove accidental URLs/Markdown links; enforce bracket citations only + import re as _re + txt = _re.sub(r"https?://\S+", "", txt) + # Convert Markdown links [text](url) or even incomplete [text]( to [text] + txt = _re.sub(r"\[([^\]]+)\]\s*\([^\)]*\)?", r"[\1]", txt) # Cleanup repetition - txt = _cleanup_answer(txt, max_chars=(safe_int(os.environ.get("CTX_SUMMARY_CHARS", ""), default=0, logger=logger, context="CTX_SUMMARY_CHARS") or None)) + txt = _cleanup_answer( + txt, + max_chars=( + safe_int( + os.environ.get("CTX_SUMMARY_CHARS", ""), + default=0, + logger=logger, + context="CTX_SUMMARY_CHARS", + ) + or None + ), + ) # Strict two-line (optional via env); otherwise remove labels and keep concise try: - def_part = ""; usage_part = "" + def_part = "" + usage_part = "" if "Usage:" in txt: parts = txt.split("Usage:", 1) def_part = parts[0] @@ -4848,7 +5657,11 @@ def _fmt_citation(cid: int | None) -> str: def_line = None if asked_ident and def_line_exact: - cid = def_id if (def_id is not None) else (citations[0]["id"] if citations else None) + cid = ( + def_id + if (def_id is not None) + else (citations[0]["id"] if citations else None) + ) def_line = f'Definition: "{def_line_exact}"{_fmt_citation(cid)}' else: cand = def_part.strip().strip("\n ") @@ -4856,9 +5669,6 @@ def _fmt_citation(cid: int | None) -> str: cand = "" m = _re.search(r'"([^"]+)"', cand) - - - q = m.group(1) if m else cand if asked_ident and asked_ident in q: cid = citations[0]["id"] if citations else None @@ -4866,7 +5676,8 @@ def _fmt_citation(cid: int | None) -> str: if not def_line: def_line = "Definition: Not found in provided snippets." - usage_text = ""; usage_cid: int | None = None + usage_text = "" + usage_cid: int | None = None try: if asked_ident and (usage_id is not None): _sn = snippets_by_id.get(usage_id) or "" @@ -4875,25 +5686,51 @@ def _fmt_citation(cid: int | None) -> str: if _re.match(rf"\s*{_re.escape(asked_ident)}\s*=", _ln): continue if asked_ident in _ln: - usage_text = _ln.strip(); usage_cid = usage_id; break + usage_text = _ln.strip() + usage_cid = usage_id + break except Exception: - usage_text = ""; usage_cid = None + usage_text = "" + usage_cid = None if not usage_text: usage_text = usage_part.strip().replace("\n", " ") if usage_part else "" usage_text = _re.sub(r"\s+", " ", usage_text).strip() if not usage_text: if usage_id is not None: - usage_text = "Appears in the shown code."; usage_cid = usage_id + usage_text = "Appears in the shown code." + usage_cid = usage_id else: - usage_text = "Not found in provided snippets."; usage_cid = def_id if (def_id is not None) else (citations[0]["id"] if citations else None) + usage_text = "Not found in provided snippets." + usage_cid = ( + def_id + if (def_id is not None) + else (citations[0]["id"] if citations else None) + ) if "[" not in usage_text and "]" not in usage_text: - uid = usage_cid if (usage_cid is not None) else (usage_id if (usage_id is not None) else (def_id if (def_id is not None) else (citations[0]["id"] if citations else None))) + uid = ( + usage_cid + if (usage_cid is not None) + else ( + usage_id + if (usage_id is not None) + else ( + def_id + if (def_id is not None) + else (citations[0]["id"] if citations else None) + ) + ) + ) usage_line = f"Usage: {usage_text}{_fmt_citation(uid)}" else: usage_line = f"Usage: {usage_text}" - if str(os.environ.get("CTX_ENFORCE_TWO_LINES", "0")).strip().lower() in {"1", "true", "yes", "on"}: + if str(os.environ.get("CTX_ENFORCE_TWO_LINES", "0")).strip().lower() in { + "1", + "true", + "yes", + "on", + }: txt = f"{def_line}\n{usage_line}".strip() else: txt = _strip_preamble_labels(txt) @@ -4930,7 +5767,6 @@ def _fmt_citation(cid: int | None) -> str: return txt - def _synthesize_from_citations( *, asked_ident: str | None, @@ -4944,6 +5780,7 @@ def _synthesize_from_citations( Returns 1–2 short lines with inline bracket citations when possible. """ import re as _re + snippets_by_id = snippets_by_id or {} def _fmt(cid: int | None) -> str: @@ -4954,7 +5791,11 @@ def _fmt(cid: int | None) -> str: # Prefer a definition-style line when an identifier is asked if asked_ident: if def_line_exact: - cid = def_id if (def_id is not None) else (citations[0].get("id") if citations else None) + cid = ( + def_id + if (def_id is not None) + else (citations[0].get("id") if citations else None) + ) lines.append(f'Definition: "{def_line_exact}"{_fmt(cid)}') else: # Try to harvest a definition-like line from snippets @@ -4974,12 +5815,17 @@ def _fmt(cid: int | None) -> str: lines.append(f'Definition: "{best_line}"{_fmt(best_cid)}') # Usage line when possible - use_line = ""; use_cid: int | None = None + use_line = "" + use_cid: int | None = None if usage_id is not None: sn = snippets_by_id.get(int(usage_id), "") or "" for ln in sn.splitlines(): - if asked_ident in ln and not _re.match(rf"\s*{_re.escape(asked_ident)}\s*=", ln): - use_line = ln.strip(); use_cid = usage_id; break + if asked_ident in ln and not _re.match( + rf"\s*{_re.escape(asked_ident)}\s*=", ln + ): + use_line = ln.strip() + use_cid = usage_id + break if not use_line: # fall back to first citation line mentioning the ident for c in citations: @@ -4987,7 +5833,9 @@ def _fmt(cid: int | None) -> str: sn = snippets_by_id.get(int(sid) if sid is not None else -1) or "" for ln in sn.splitlines(): if asked_ident in ln: - use_line = ln.strip(); use_cid = sid; break + use_line = ln.strip() + use_cid = sid + break if use_line: break if use_line: @@ -5036,7 +5884,7 @@ async def context_answer( not_glob: Any = None, case: Any = None, not_: Any = None, - **kwargs, + kwargs: Any = None, ) -> Dict[str, Any]: """Natural-language Q&A over the repo using retrieval + local LLM (llama.cpp). @@ -5068,10 +5916,27 @@ async def context_answer( """ # Normalize inputs and compute effective limits/flags _cfg = _ca_unwrap_and_normalize( - query, limit, per_path, budget_tokens, include_snippet, collection, - - max_tokens, temperature, mode, expand, language, under, kind, symbol, - ext, path_regex, path_glob, not_glob, case, not_, kwargs, + query, + limit, + per_path, + budget_tokens, + include_snippet, + collection, + max_tokens, + temperature, + mode, + expand, + language, + under, + kind, + symbol, + ext, + path_regex, + path_glob, + not_glob, + case, + not_, + kwargs, ) queries = _cfg["queries"] lim = _cfg["limit"] @@ -5094,17 +5959,43 @@ async def context_answer( not_glob = _flt.get("not_glob") case = _flt.get("case") not_ = _flt.get("not_") + # Enforce sane minimums to avoid empty span selection + try: + lim = int(lim) + except Exception: + lim = 15 + if lim <= 0: + lim = 1 + try: + ppath = int(ppath) + except Exception: + ppath = 5 + if ppath <= 0: + ppath = 1 # Soft per-call deadline to avoid client-side 60s timeouts _ca_start_ts = time.time() if os.environ.get("DEBUG_CONTEXT_ANSWER"): - logger.debug("ARG_SHAPE", extra={"normalized_queries": queries, "limit": lim, "per_path": ppath}) + logger.debug( + "ARG_SHAPE", + extra={"normalized_queries": queries, "limit": lim, "per_path": ppath}, + ) # Broad-query budget bump (gated). If user didn't pass budget, scale env default; else scale provided value. try: _qtext = " ".join([q for q in (queries or []) if isinstance(q, str)]).lower() - _broad_tokens = ("how","explain","overview","architecture","design","work","works","guide","readme") + _broad_tokens = ( + "how", + "explain", + "overview", + "architecture", + "design", + "work", + "works", + "guide", + "readme", + ) _broad = any(t in _qtext for t in _broad_tokens) except Exception: _broad = False @@ -5132,8 +6023,10 @@ async def context_answer( model = _get_embedding_model(model_name) # Prepare environment toggles for ReFRAG gate-first and budgeting - # Acquire lock to avoid cross-request env clobbering - _ENV_LOCK.acquire() + # Acquire lock to avoid cross-request env clobbering (with timeout) + if not _ENV_LOCK.acquire(timeout=30.0): + logger.warning("ENV_LOCK timeout, potential deadlock detected") + # Continue anyway to avoid complete deadlock, but log the issue prev = { "REFRAG_MODE": os.environ.get("REFRAG_MODE"), "REFRAG_GATE_FIRST": os.environ.get("REFRAG_GATE_FIRST"), @@ -5145,7 +6038,9 @@ async def context_answer( try: # Enable ReFRAG gate-first for context compression os.environ["REFRAG_MODE"] = "1" - os.environ["REFRAG_GATE_FIRST"] = os.environ.get("REFRAG_GATE_FIRST", "1") or "1" + os.environ["REFRAG_GATE_FIRST"] = ( + os.environ.get("REFRAG_GATE_FIRST", "1") or "1" + ) os.environ["COLLECTION_NAME"] = coll if budget_tokens is not None and str(budget_tokens).strip() != "": os.environ["MICRO_BUDGET_TOKENS"] = str(budget_tokens) @@ -5154,15 +6049,26 @@ async def context_answer( # For LLM answering, default to include snippets so the model sees actual code if include_snippet in (None, ""): include_snippet = True - did_local_expand = False # Ensure defined even if expansion is disabled or fails - + did_local_expand = ( + False # Ensure defined even if expansion is disabled or fails + ) - do_expand = safe_bool(expand, default=False, logger=logger, context="expand") or \ - safe_bool(os.environ.get("HYBRID_EXPAND", "0"), default=False, logger=logger, context="HYBRID_EXPAND") + do_expand = safe_bool( + expand, default=False, logger=logger, context="expand" + ) or safe_bool( + os.environ.get("HYBRID_EXPAND", "0"), + default=False, + logger=logger, + context="HYBRID_EXPAND", + ) if do_expand: try: - from scripts.refrag_llamacpp import LlamaCppRefragClient, is_decoder_enabled # type: ignore + from scripts.refrag_llamacpp import ( + LlamaCppRefragClient, + is_decoder_enabled, + ) # type: ignore + if is_decoder_enabled(): prompt = ( "You expand code search queries. Given one or more short queries, " @@ -5172,13 +6078,15 @@ async def context_answer( client = LlamaCppRefragClient() # tight decoding for expansions out = client.generate_with_soft_embeddings( - prompt=prompt, max_tokens=int(os.environ.get("EXPAND_MAX_TOKENS", "64") or 64), + prompt=prompt, + max_tokens=int(os.environ.get("EXPAND_MAX_TOKENS", "64") or 64), temperature=0.0, # Always 0 for deterministic expansion top_k=int(os.environ.get("EXPAND_TOP_K", "30") or 30), top_p=float(os.environ.get("EXPAND_TOP_P", "0.9") or 0.9), - stop=["\n\n"] + stop=["\n\n"], ) import json as _json + alts = [] try: parsed = _json.loads(out) @@ -5188,7 +6096,7 @@ async def context_answer( start = out.find("[") end = out.rfind("]") if start != -1 and end != -1 and end > start: - parsed = _json.loads(out[start:end+1]) + parsed = _json.loads(out[start : end + 1]) else: parsed = [] except Exception as e2: @@ -5202,16 +6110,24 @@ async def context_answer( break if not alts and out and out.strip(): # Heuristic fallback: split lines, trim bullets, take up to 2 - for cand in [t.strip().lstrip("-β€’ ") for t in out.splitlines() if t.strip()][:2]: + for cand in [ + t.strip().lstrip("-β€’ ") + for t in out.splitlines() + if t.strip() + ][:2]: if cand and cand not in queries and len(alts) < 2: alts.append(cand) if alts: queries.extend(alts) did_local_expand = True # Mark that we already expanded except (ImportError, AttributeError) as e: - logger.warning("Query expansion failed (decoder unavailable)", exc_info=e) + logger.warning( + "Query expansion failed (decoder unavailable)", exc_info=e + ) except (TimeoutError, ConnectionError) as e: - logger.warning("Query expansion failed (decoder timeout/connection)", exc_info=e) + logger.warning( + "Query expansion failed (decoder timeout/connection)", exc_info=e + ) except Exception as e: logger.error("Unexpected error during query expansion", exc_info=e) @@ -5293,18 +6209,23 @@ def _to_glob_list(val: Any) -> list[str]: if v is None: try: del os.environ[k] - except Exception: - pass + except Exception as e: + logger.error(f"Failed to restore env var {k}: {e}") else: os.environ[k] = v _ENV_LOCK.release() if err is not None: - return {"error": f"hybrid search failed: {err}", "citations": [], "query": queries} + return { + "error": f"hybrid search failed: {err}", + "citations": [], + "query": queries, + } # Ensure final retrieval call reflects Tier-2 relaxed filters for tests/introspection try: from scripts.hybrid_search import run_hybrid_search as _rh # type: ignore + _ = _rh( queries=queries, limit=int(max(lim, 1)), @@ -5314,17 +6235,59 @@ def _to_glob_list(val: Any) -> list[str]: pass # Build citations and context payload for the decoder - citations, context_blocks, snippets_by_id, asked_ident, _def_line_exact, _def_id, _usage_id = _ca_build_citations_and_context( + ( + citations, + context_blocks, + snippets_by_id, + asked_ident, + _def_line_exact, + _def_id, + _usage_id, + ) = _ca_build_citations_and_context( spans=spans, include_snippet=bool(include_snippet), queries=queries, ) + # Salvage: if citations are empty but we have items, rebuild from raw items + if not citations: + try: + ( + citations2, + context_blocks2, + snippets_by_id2, + asked_ident2, + _def_line_exact2, + _def_id2, + _usage_id2, + ) = _ca_build_citations_and_context( + spans=(items or []), + include_snippet=bool(include_snippet), + queries=queries, + ) + if citations2: + citations = citations2 + context_blocks = context_blocks2 + snippets_by_id = snippets_by_id2 + asked_ident = asked_ident2 + _def_line_exact = _def_line_exact2 + _def_id = _def_id2 + _usage_id = _usage_id2 + except Exception: + pass + # If still no citations, return an explicit insufficient-context answer + if not citations: + return { + "answer": "insufficient context", + "citations": [], + "query": queries, + "used": {"gate_first": True, "refrag": True, "no_citations": True}, + } # If an identifier was asked and we didn't capture its definition yet, # do a tiny FS supplement over candidate paths (from retrieved items and explicit filename in query). if asked_ident and not _def_line_exact: cand_paths: list[str] = [] - for it in (items or []): + for it in items or []: p = it.get("path") or it.get("host_path") or it.get("container_path") if p and str(p) not in cand_paths: cand_paths.append(str(p)) @@ -5332,6 +6295,7 @@ def _to_glob_list(val: Any) -> list[str]: try: qj3 = " ".join(queries) import re as _re + m = _re.search(r"in\s+([\w./-]+\.py)\b", qj3) if m: fp = m.group(1) @@ -5340,12 +6304,27 @@ def _to_glob_list(val: Any) -> list[str]: except Exception: pass supplements = [] - if str(os.environ.get("CTX_TIER3_FS", "0")).strip().lower() in {"1","true","yes","on"}: - supplements = _ca_ident_supplement(cand_paths, asked_ident, include_snippet=bool(include_snippet), max_hits=3) + if str(os.environ.get("CTX_TIER3_FS", "0")).strip().lower() in { + "1", + "true", + "yes", + "on", + }: + supplements = _ca_ident_supplement( + cand_paths, + asked_ident, + include_snippet=bool(include_snippet), + max_hits=3, + ) if supplements: # Prepend supplements so the decoder sees them first def _k(s: Dict[str, Any]): - return (str(s.get("path") or ""), int(s.get("start_line") or 0), int(s.get("end_line") or 0)) + return ( + str(s.get("path") or ""), + int(s.get("start_line") or 0), + int(s.get("end_line") or 0), + ) + seen_keys = {_k(s) for s in spans} new_spans = [] for s in supplements: @@ -5355,7 +6334,15 @@ def _k(s: Dict[str, Any]): seen_keys.add(k) if new_spans: spans = new_spans + spans - citations, context_blocks, snippets_by_id, asked_ident, _def_line_exact, _def_id, _usage_id = _ca_build_citations_and_context( + ( + citations, + context_blocks, + snippets_by_id, + asked_ident, + _def_line_exact, + _def_id, + _usage_id, + ) = _ca_build_citations_and_context( spans=spans, include_snippet=bool(include_snippet), queries=queries, @@ -5363,9 +6350,14 @@ def _k(s: Dict[str, Any]): # Debug: log span details if os.environ.get("DEBUG_CONTEXT_ANSWER"): - logger.debug("CONTEXT_BLOCKS", extra={"spans": len(spans), "context_blocks": len(context_blocks), "previews": [block[:300] for block in context_blocks[:3]]}) - - + logger.debug( + "CONTEXT_BLOCKS", + extra={ + "spans": len(spans), + "context_blocks": len(context_blocks), + "previews": [block[:300] for block in context_blocks[:3]], + }, + ) # Stop sequences for Granite-4.0-Micro + optional env overrides stop_env = os.environ.get("DECODER_STOP", "") @@ -5377,10 +6369,10 @@ def _k(s: Dict[str, Any]): ] stops = default_stops + [s for s in (stop_env.split(",") if stop_env else []) if s] - # Ensure the last retrieval call reflects Tier-2 relaxed filters for tests/introspection try: from scripts.hybrid_search import run_hybrid_search as _rhs # type: ignore + _ = _rhs( queries=queries, limit=1, @@ -5395,17 +6387,41 @@ def _k(s: Dict[str, Any]): path_regex=None, path_glob=None, not_glob=eff_not_glob, - expand=False if did_local_expand else (str(os.environ.get("HYBRID_EXPAND", "0")).strip().lower() in {"1","true","yes","on"}), + expand=False + if did_local_expand + else ( + str(os.environ.get("HYBRID_EXPAND", "0")).strip().lower() + in {"1", "true", "yes", "on"} + ), model=model, ) except Exception: pass # Deadline-aware decode budgeting - _client_deadline_sec = safe_float(os.environ.get("CTX_CLIENT_DEADLINE_SEC", "178"), default=178.0, logger=logger, context="CTX_CLIENT_DEADLINE_SEC") - _tokens_per_sec = safe_float(os.environ.get("DECODER_TOKENS_PER_SEC", ""), default=10.0, logger=logger, context="DECODER_TOKENS_PER_SEC") - _decoder_timeout_cap = safe_float(os.environ.get("CTX_DECODER_TIMEOUT_CAP", "170"), default=170.0, logger=logger, context="CTX_DECODER_TIMEOUT_CAP") - _deadline_margin = safe_float(os.environ.get("CTX_DEADLINE_MARGIN_SEC", "6"), default=6.0, logger=logger, context="CTX_DEADLINE_MARGIN_SEC") - + _client_deadline_sec = safe_float( + os.environ.get("CTX_CLIENT_DEADLINE_SEC", "178"), + default=178.0, + logger=logger, + context="CTX_CLIENT_DEADLINE_SEC", + ) + _tokens_per_sec = safe_float( + os.environ.get("DECODER_TOKENS_PER_SEC", ""), + default=10.0, + logger=logger, + context="DECODER_TOKENS_PER_SEC", + ) + _decoder_timeout_cap = safe_float( + os.environ.get("CTX_DECODER_TIMEOUT_CAP", "170"), + default=170.0, + logger=logger, + context="CTX_DECODER_TIMEOUT_CAP", + ) + _deadline_margin = safe_float( + os.environ.get("CTX_DEADLINE_MARGIN_SEC", "6"), + default=6.0, + logger=logger, + context="CTX_DEADLINE_MARGIN_SEC", + ) # Decoder params and stops mtok, temp, top_k, top_p, stops = _ca_decoder_params(max_tokens) @@ -5413,8 +6429,11 @@ def _k(s: Dict[str, Any]): # Call llama.cpp decoder (requires REFRAG_DECODER=1) try: from scripts.refrag_llamacpp import is_decoder_enabled # type: ignore + if not is_decoder_enabled(): - logger.info("Decoder disabled; returning extractive fallback with citations") + logger.info( + "Decoder disabled; returning extractive fallback with citations" + ) _fallback_txt = _ca_postprocess_answer( "", citations, @@ -5433,12 +6452,16 @@ def _k(s: Dict[str, Any]): } # SIMPLE APPROACH: One LLM call with all context - all_context = "\n\n".join(context_blocks) if context_blocks else "(no code found)" + all_context = ( + "\n\n".join(context_blocks) if context_blocks else "(no code found)" + ) # Derive lightweight usage hint heuristics to anchor tiny models extra_hint = "" try: - if ("def rrf(" in all_context) and ("/(k + rank)" in all_context or "/ (k + rank)" in all_context): + if ("def rrf(" in all_context) and ( + "/(k + rank)" in all_context or "/ (k + rank)" in all_context + ): extra_hint = "RRF (Reciprocal Rank Fusion) formula 1.0 / (k + rank); parameter k defaults to RRF_K in def rrf." except Exception: extra_hint = "" @@ -5468,14 +6491,26 @@ def _k(s: Dict[str, Any]): } # Tighten max_tokens and decoder HTTP timeout to fit remaining time try: - _allow_tokens = int(max(16.0, min(float(mtok), max(0.0, _remain - max(0.0, float(_deadline_margin) - 2.0)) * float(_tokens_per_sec)))) + _allow_tokens = int( + max( + 16.0, + min( + float(mtok), + max(0.0, _remain - max(0.0, float(_deadline_margin) - 2.0)) + * float(_tokens_per_sec), + ), + ) + ) except Exception: _allow_tokens = int(max(16, int(mtok))) mtok = int(_allow_tokens) _llama_timeout = int( - max(5.0, min(_decoder_timeout_cap, max(1.0, _remain - 1.0)))) + max(5.0, min(_decoder_timeout_cap, max(1.0, _remain - 1.0))) + ) with _env_overrides({"LLAMACPP_TIMEOUT_SEC": str(_llama_timeout)}): - answer = _ca_decode(prompt, mtok=mtok, temp=temp, top_k=top_k, top_p=top_p, stops=stops) + answer = _ca_decode( + prompt, mtok=mtok, temp=temp, top_k=top_k, top_p=top_p, stops=stops + ) # Post-process and validate answer = _ca_postprocess_answer( @@ -5489,11 +6524,16 @@ def _k(s: Dict[str, Any]): ) except Exception as e: - return {"error": f"decoder call failed: {e}", "citations": citations, "query": queries} + return { + "error": f"decoder call failed: {e}", + "citations": citations, + "query": queries, + } # Final introspection call to ensure last search reflects relaxed filters try: from scripts.hybrid_search import run_hybrid_search as _rh2 # type: ignore + _ = _rh2( queries=queries, limit=int(max(lim, 1)), @@ -5510,12 +6550,20 @@ def _k(s: Dict[str, Any]): def _tok2(s: str) -> list[str]: try: - return [w.lower() for w in _re.split(r"[^A-Za-z0-9_]+", str(s or "")) if len(w) >= 3] + return [ + w.lower() + for w in _re.split(r"[^A-Za-z0-9_]+", str(s or "")) + if len(w) >= 3 + ] except Exception: return [] # Build quick lookups from the combined retrieval we already computed - id_to_cit = {int(c.get("id") or 0): c for c in (citations or []) if int(c.get("id") or 0) > 0} + id_to_cit = { + int(c.get("id") or 0): c + for c in (citations or []) + if int(c.get("id") or 0) > 0 + } id_to_block = {idx + 1: blk for idx, blk in enumerate(context_blocks or [])} answers_by_query = [] @@ -5529,19 +6577,41 @@ def _tok2(s: str) -> list[str]: sn = (snippets_by_id.get(cid) or "").lower() if any(t in sn or t in path_l for t in toks): picked_ids.append(cid) - if len(picked_ids) >= 6: # small cap per query to keep prompt compact + if ( + len(picked_ids) >= 6 + ): # small cap per query to keep prompt compact break # Fallback if nothing matched: take the first 2 citations if not picked_ids: - picked_ids = [c.get("id") for c in (citations or [])[:2] if c.get("id")] + picked_ids = [ + c.get("id") for c in (citations or [])[:2] if c.get("id") + ] # Assemble per-query citations and context blocks using the shared retrieval cits_i = [id_to_cit[cid] for cid in picked_ids if cid in id_to_cit] - ctx_blocks_i = [id_to_block[cid] for cid in picked_ids if cid in id_to_block] - + ctx_blocks_i = [ + id_to_block[cid] for cid in picked_ids if cid in id_to_block + ] + # If we still have no citations for this query, bail early + if not cits_i: + answers_by_query.append( + { + "query": q, + "answer": "insufficient context", + "citations": [], + } + ) + continue # Decode per-query with the subset of shared context prompt_i = _ca_build_prompt(ctx_blocks_i, cits_i, [q]) - ans_raw_i = _ca_decode(prompt_i, mtok=mtok, temp=temp, top_k=top_k, top_p=top_p, stops=stops) + ans_raw_i = _ca_decode( + prompt_i, + mtok=mtok, + temp=temp, + top_k=top_k, + top_p=top_p, + stops=stops, + ) # Minimal post-processing with per-query identifier inference asked_ident_i = _primary_identifier_from_queries([q]) @@ -5552,21 +6622,27 @@ def _tok2(s: str) -> list[str]: def_line_exact=None, def_id=None, usage_id=None, - snippets_by_id={cid: snippets_by_id.get(cid, "") for cid in picked_ids}, + snippets_by_id={ + cid: snippets_by_id.get(cid, "") for cid in picked_ids + }, ) - answers_by_query.append({ - "query": q, - "answer": ans_i, - "citations": cits_i, - }) + answers_by_query.append( + { + "query": q, + "answer": ans_i, + "citations": cits_i, + } + ) except Exception as _e: - answers_by_query.append({ - "query": q, - "answer": "", - "citations": [], - "error": str(_e), - }) + answers_by_query.append( + { + "query": q, + "answer": "", + "citations": [], + "error": str(_e), + } + ) except Exception: answers_by_query = None @@ -5605,7 +6681,7 @@ async def code_search( not_: Any = None, case: Any = None, compact: Any = None, - **kwargs, + kwargs: Any = None, ) -> Dict[str, Any]: """Exact alias of repo_search (hybrid code search). @@ -5635,12 +6711,13 @@ async def code_search( not_=not_, case=case, compact=compact, - **kwargs, + kwargs=kwargs, ) -if __name__ == "__main__": +_relax_var_kwarg_defaults() +if __name__ == "__main__": # Optional warmups: gated by env flags to avoid delaying readiness on fresh containers try: if str(os.environ.get("EMBEDDING_WARMUP", "")).strip().lower() in { diff --git a/scripts/refrag_llamacpp.py b/scripts/refrag_llamacpp.py index 199251a2..9749e393 100644 --- a/scripts/refrag_llamacpp.py +++ b/scripts/refrag_llamacpp.py @@ -11,6 +11,9 @@ from __future__ import annotations import os from typing import Any, Dict, Optional +import threading +import contextlib +import time def _bool_env(name: str, default: str = "0") -> bool: @@ -42,6 +45,72 @@ def get_sense_policy() -> str: return str(os.environ.get("REFRAG_SENSE", "heuristic")).strip().lower() +def _max_parallel() -> int: + try: + val = int(os.environ.get("LLAMACPP_MAX_PARALLEL", "").strip() or "2") + return max(1, val) + except Exception: + return 2 + + +_LLAMACPP_PARALLEL = _max_parallel() +_LLAMACPP_SLOT = threading.Semaphore(_LLAMACPP_PARALLEL) +_LLAMACPP_SLOT_LOCK = threading.Lock() + + +def _refresh_parallel_semaphore() -> None: + """Refresh semaphore when LLAMACPP_MAX_PARALLEL changes at runtime.""" + try: + desired = _max_parallel() + except Exception: + desired = 2 + with _LLAMACPP_SLOT_LOCK: + global _LLAMACPP_SLOT, _LLAMACPP_PARALLEL + if desired == _LLAMACPP_PARALLEL: + return + _LLAMACPP_PARALLEL = desired + _LLAMACPP_SLOT = threading.Semaphore(desired) + + +@contextlib.contextmanager +def _parallel_slot(): + """Context manager honoring LLAMACPP_MAX_PARALLEL.""" + _refresh_parallel_semaphore() + slot = globals().get("_LLAMACPP_SLOT") + if not isinstance(slot, threading.Semaphore): + slot = threading.Semaphore(_max_parallel()) + globals()["_LLAMACPP_SLOT"] = slot + acquired = slot.acquire(timeout=float(os.environ.get("LLAMACPP_ACQUIRE_TIMEOUT", "30") or 30)) + if not acquired: + raise RuntimeError("llama.cpp saturated: parallel limit reached") + try: + yield + finally: + slot.release() + + +_WARM_CHECKED = False + + +def _maybe_warm(base_url: str) -> None: + global _WARM_CHECKED + if _WARM_CHECKED: + return + _WARM_CHECKED = True + if not _bool_env("LLAMACPP_AUTOWARM", "1"): + return + try: + from urllib import request + + req = request.Request(base_url.rstrip("/") + "/health", method="GET") + timeout = float(os.environ.get("LLAMACPP_WARM_TIMEOUT", "3") or 3) + with request.urlopen(req, timeout=timeout): + pass + except Exception: + # Ignore warm failures; decoder calls will raise later if truly unavailable + return + + class LlamaCppRefragClient: """Feature-flagged client for llama.cpp decoder. @@ -51,9 +120,22 @@ class LlamaCppRefragClient: """ def __init__(self, base_url: Optional[str] = None) -> None: - self.base_url = base_url or os.environ.get( - "LLAMACPP_URL", "http://localhost:8080" - ) + if base_url: + self.base_url = base_url + else: + # Smart URL resolution: GPU vs Docker based on USE_GPU_DECODER flag + use_gpu = str(os.environ.get("USE_GPU_DECODER", "0")).strip().lower() + if use_gpu in {"1", "true", "yes", "on"}: + # Use native GPU-accelerated server + # Use localhost when running on host, host.docker.internal when in container + if os.path.exists("/.dockerenv"): + self.base_url = "http://host.docker.internal:8081" + else: + self.base_url = "http://localhost:8081" + else: + # Use configured LLAMACPP_URL (default: Docker CPU-only) + self.base_url = os.environ.get("LLAMACPP_URL", "http://localhost:8080") + _maybe_warm(self.base_url) if get_runtime_kind() != "llamacpp": raise ValueError( "REFRAG_RUNTIME must be 'llamacpp' for LlamaCppRefragClient" @@ -104,7 +186,13 @@ def generate_with_soft_embeddings( "stop": gen_kwargs.get("stop") or [], } try: - res = self._post("/soft_completion", payload) + with _parallel_slot(): + _start = time.time() + res = self._post("/soft_completion", payload) + elapsed = time.time() - _start + os.environ.setdefault( + "LLAMACPP_LAST_LATENCY_SEC", f"{elapsed:.3f}" + ) except Exception as e: raise RuntimeError(f"llama.cpp soft_completion failed: {e}") return (res.get("content") or res.get("generation") or "").strip() @@ -128,7 +216,13 @@ def generate_with_soft_embeddings( "stop": gen_kwargs.get("stop") or [], } try: - res = self._post("/completion", payload) + with _parallel_slot(): + _start = time.time() + res = self._post("/completion", payload) + elapsed = time.time() - _start + os.environ.setdefault( + "LLAMACPP_LAST_LATENCY_SEC", f"{elapsed:.3f}" + ) except Exception as e: raise RuntimeError(f"llama.cpp completion failed: {e}") # llama.cpp server returns { 'content': '...' } or { 'token': ... } streams; we expect non-stream diff --git a/scripts/utils.py b/scripts/utils.py index 69c13610..dfdad7fc 100644 --- a/scripts/utils.py +++ b/scripts/utils.py @@ -69,6 +69,7 @@ def lex_hash_vector_queries(phrases: list[str], dim: int = 4096) -> list[float]: def highlight_snippet(snippet: str, tokens: list[str]) -> str: if not snippet or not tokens: return snippet + # longest first to avoid partial overlaps toks = sorted(set(tokens), key=len, reverse=True) import re as _re diff --git a/scripts/watch_index.py b/scripts/watch_index.py index e94ccc09..6052441d 100644 --- a/scripts/watch_index.py +++ b/scripts/watch_index.py @@ -58,8 +58,8 @@ def add(self, p: Path): if self._timer is not None: try: self._timer.cancel() - except Exception: - pass + except Exception as e: + logger.error(f"Failed to cancel timer in ChangeQueue.add: {e}") self._timer = threading.Timer(DELAY_SECS, self._flush) self._timer.daemon = True self._timer.start() @@ -88,8 +88,10 @@ def _flush(self): except Exception as e: try: print(f"[watcher_error] processing batch failed: {e}") - except Exception: - pass + except Exception as inner_e: + logger.error( + f"Exception in ChangeQueue._flush during batch processing: {inner_e}" + ) # drain any pending accumulated during processing with self._lock: if not self._pending: @@ -101,7 +103,9 @@ def _flush(self): class IndexHandler(FileSystemEventHandler): - def __init__(self, root: Path, queue: ChangeQueue, client: QdrantClient, collection: str): + def __init__( + self, root: Path, queue: ChangeQueue, client: QdrantClient, collection: str + ): super().__init__() self.root = root self.queue = queue @@ -115,14 +119,18 @@ def __init__(self, root: Path, queue: ChangeQueue, client: QdrantClient, collect except Exception: self._ignore_path = None self._ignore_mtime = ( - (self._ignore_path.stat().st_mtime if self._ignore_path and self._ignore_path.exists() else 0.0) + self._ignore_path.stat().st_mtime + if self._ignore_path and self._ignore_path.exists() + else 0.0 ) def _maybe_reload_excluder(self): try: if not self._ignore_path: return - cur = self._ignore_path.stat().st_mtime if self._ignore_path.exists() else 0.0 + cur = ( + self._ignore_path.stat().st_mtime if self._ignore_path.exists() else 0.0 + ) if cur != self._ignore_mtime: self.excl = idx._Excluder(self.root) self._ignore_mtime = cur @@ -212,17 +220,24 @@ def on_moved(self, event): except Exception: return # Only react to code files - if dest.suffix.lower() not in idx.CODE_EXTS and src.suffix.lower() not in idx.CODE_EXTS: + if ( + dest.suffix.lower() not in idx.CODE_EXTS + and src.suffix.lower() not in idx.CODE_EXTS + ): return # If destination directory is ignored, treat as simple deletion try: - rel_dir = "/" + str(dest.parent.resolve().relative_to(self.root.resolve())).replace(os.sep, "/") + rel_dir = "/" + str( + dest.parent.resolve().relative_to(self.root.resolve()) + ).replace(os.sep, "/") if rel_dir == "/.": rel_dir = "/" if self.excl.exclude_dir(rel_dir): if src.suffix.lower() in idx.CODE_EXTS: try: - idx.delete_points_by_path(self.client, self.collection, str(src)) + idx.delete_points_by_path( + self.client, self.collection, str(src) + ) print(f"[moved:ignored_dest_deleted_src] {src} -> {dest}") try: remove_cached_file(str(self.root), str(src)) @@ -261,7 +276,12 @@ def on_moved(self, event): except Exception: pass try: - _log_activity(str(self.root), "moved", dest, {"from": str(src), "chunks": int(moved_count)}) + _log_activity( + str(self.root), + "moved", + dest, + {"from": str(src), "chunks": int(moved_count)}, + ) except Exception: pass return @@ -277,6 +297,7 @@ def on_moved(self, event): except Exception: pass + # --- Workspace state helpers --- def _set_status_indexing(workspace_path: str, total_files: int) -> None: try: @@ -293,7 +314,11 @@ def _set_status_indexing(workspace_path: str, total_files: int) -> None: def _update_progress( - workspace_path: str, started_at: str, processed: int, total: int, current_file: Path | None + workspace_path: str, + started_at: str, + processed: int, + total: int, + current_file: Path | None, ) -> None: try: update_indexing_status( @@ -312,7 +337,9 @@ def _update_progress( pass -def _log_activity(workspace_path: str, action: str, file_path: Path, details: dict | None = None) -> None: +def _log_activity( + workspace_path: str, action: str, file_path: Path, details: dict | None = None +) -> None: try: update_last_activity( workspace_path, @@ -328,7 +355,9 @@ def _log_activity(workspace_path: str, action: str, file_path: Path, details: di # --- Move/Rename optimization: reuse vectors when file content unchanged --- -def _rename_in_store(client: QdrantClient, collection: str, src: Path, dest: Path) -> int: +def _rename_in_store( + client: QdrantClient, collection: str, src: Path, dest: Path +) -> int: """Best-effort: if dest content hash matches previously indexed src hash, update points in-place to the new path without re-embedding. @@ -369,7 +398,7 @@ def _rename_in_store(client: QdrantClient, collection: str, src: Path, dest: Pat new_points = [] for rec in points: payload = rec.payload or {} - md = (payload.get("metadata") or {}) + md = payload.get("metadata") or {} code = md.get("code") or "" try: start_line = int(md.get("start_line") or 1) @@ -384,7 +413,9 @@ def _rename_in_store(client: QdrantClient, collection: str, src: Path, dest: Pat new_md["path_prefix"] = str(dest.parent) # Recompute dual-path hints cur_path = str(dest) - host_root = str(os.environ.get("HOST_INDEX_PATH") or "").strip().rstrip("/") + host_root = ( + str(os.environ.get("HOST_INDEX_PATH") or "").strip().rstrip("/") + ) host_path = None container_path = None try: @@ -490,7 +521,9 @@ def main(): except Exception: pass - q = ChangeQueue(lambda paths: _process_paths(paths, client, model, vector_name, str(ROOT))) + q = ChangeQueue( + lambda paths: _process_paths(paths, client, model, vector_name, str(ROOT)) + ) handler = IndexHandler(ROOT, q, client, COLLECTION) obs = Observer() @@ -546,12 +579,19 @@ def _process_paths(paths, client, model, vector_name: str, workspace_path: str): # Lazily instantiate model if needed if model is None: from fastembed import TextEmbedding + mname = os.environ.get("EMBEDDING_MODEL", "BAAI/bge-base-en-v1.5") model = TextEmbedding(model_name=mname) ok = False try: ok = idx.index_single_file( - client, model, COLLECTION, vector_name, p, dedupe=True, skip_unchanged=True + client, + model, + COLLECTION, + vector_name, + p, + dedupe=True, + skip_unchanged=True, ) except Exception as e: try: @@ -568,7 +608,9 @@ def _process_paths(paths, client, model, vector_name: str, workspace_path: str): size = None _log_activity(workspace_path, "indexed", p, {"file_size": size}) else: - _log_activity(workspace_path, "skipped", p, {"reason": "no-change-or-error"}) + _log_activity( + workspace_path, "skipped", p, {"reason": "no-change-or-error"} + ) processed += 1 _update_progress(workspace_path, started_at, processed, total, current) finally: diff --git a/test_gpu_switch.py b/test_gpu_switch.py new file mode 100644 index 00000000..9d1027ce --- /dev/null +++ b/test_gpu_switch.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +""" +Test script to verify GPU decoder switching functionality. + +Usage: + # Test Docker CPU-only decoder + python test_gpu_switch.py + + # Test native GPU-accelerated decoder + USE_GPU_DECODER=1 python test_gpu_switch.py +""" + +import os +import sys + +def load_env_file(): + """Load environment variables from .env file.""" + env_file = '.env' + if os.path.exists(env_file): + with open(env_file, 'r') as f: + for line in f: + line = line.strip() + if line and not line.startswith('#') and '=' in line: + key, value = line.split('=', 1) + # Only set if not already in environment + if key not in os.environ: + os.environ[key] = value + +def test_decoder_url_resolution(): + """Test that the decoder URL is resolved correctly based on USE_GPU_DECODER flag.""" + + # Import the resolver function + sys.path.insert(0, 'scripts') + from refrag_llamacpp import LlamaCppRefragClient + + # Test current configuration + use_gpu = os.environ.get("USE_GPU_DECODER", "0") + print(f"USE_GPU_DECODER = {use_gpu}") + + # Create client and check URL + client = LlamaCppRefragClient() + print(f"Resolved decoder URL: {client.base_url}") + + # Test health endpoint + try: + import urllib.request + + health_url = client.base_url.rstrip('/') + '/health' + + # For Docker service names, try localhost equivalent when running on host + if 'llamacpp:8080' in health_url: + health_url = health_url.replace('llamacpp:8080', 'localhost:8080') + print(f"Testing health endpoint: {health_url} (Docker service via localhost)") + else: + print(f"Testing health endpoint: {health_url}") + + req = urllib.request.Request(health_url, method='GET') + with urllib.request.urlopen(req, timeout=5) as resp: + if resp.status == 200: + print("PASS: Decoder server is healthy and reachable") + return True + else: + print(f"FAIL: Decoder server returned status {resp.status}") + return False + + except Exception as e: + print(f"FAIL: Failed to reach decoder server: {e}") + return False + +def test_simple_completion(): + """Test a simple completion request.""" + + sys.path.insert(0, 'scripts') + from refrag_llamacpp import LlamaCppRefragClient, is_decoder_enabled + + if not is_decoder_enabled(): + print("FAIL: Decoder is disabled. Set REFRAG_DECODER=1 to enable.") + return False + + try: + client = LlamaCppRefragClient() + + # For Docker service names, use localhost equivalent when running on host + test_url = client.base_url + if 'llamacpp:8080' in client.base_url: + test_url = client.base_url.replace('llamacpp:8080', 'localhost:8080') + # Override the client's base_url for testing + client.base_url = test_url + print(f"Testing completion with decoder at: {test_url} (Docker service via localhost)") + else: + print(f"Testing completion with decoder at: {client.base_url}") + + response = client.generate_with_soft_embeddings( + prompt="What is 2+2?", + max_tokens=50, + temperature=0.1 + ) + + print(f"PASS: Completion successful: {response[:100]}...") + return True + + except Exception as e: + print(f"FAIL: Completion failed: {e}") + return False + +if __name__ == "__main__": + print("Testing GPU decoder switching functionality\n") + + # Load .env file first + load_env_file() + + # Set decoder enabled for testing + os.environ.setdefault("REFRAG_DECODER", "1") + + print("1. Testing URL resolution...") + url_ok = test_decoder_url_resolution() + + print("\n2. Testing simple completion...") + completion_ok = test_simple_completion() + + print(f"\nResults:") + print(f" URL Resolution: {'PASS' if url_ok else 'FAIL'}") + print(f" Completion Test: {'PASS' if completion_ok else 'FAIL'}") + + if url_ok and completion_ok: + print("\nAll tests passed! GPU switching is working correctly.") + sys.exit(0) + else: + print("\nSome tests failed. Check your decoder setup.") + sys.exit(1) diff --git a/tests/test_change_history_for_path.py b/tests/test_change_history_for_path.py new file mode 100644 index 00000000..a0f9c046 --- /dev/null +++ b/tests/test_change_history_for_path.py @@ -0,0 +1,95 @@ +import importlib +import sys +import types +import pytest + +# Provide a minimal stub for mcp.server.fastmcp.FastMCP so importing the server doesn't exit +mcp_pkg = types.ModuleType("mcp") +server_pkg = types.ModuleType("mcp.server") +fastmcp_pkg = types.ModuleType("mcp.server.fastmcp") + +class _FastMCP: + def __init__(self, *args, **kwargs): + pass + def tool(self, *args, **kwargs): + def _decorator(fn): + return fn + return _decorator + +setattr(fastmcp_pkg, "FastMCP", _FastMCP) +sys.modules.setdefault("mcp", mcp_pkg) +sys.modules.setdefault("mcp.server", server_pkg) +sys.modules.setdefault("mcp.server.fastmcp", fastmcp_pkg) + +srv = importlib.import_module("scripts.mcp_indexer_server") + + +class FakePoint: + def __init__(self, payload): + self.payload = payload + + +class FakeQdrant: + def __init__(self, pages): + # pages: list[list[FakePoint]] to simulate pagination + self._pages = pages + self._i = 0 + + def scroll(self, **kwargs): + if self._i >= len(self._pages): + return ([], None) + page = self._pages[self._i] + self._i += 1 + # Return next_page offset as None to stop after last page + return (page, None if self._i >= len(self._pages) else object()) + + +@pytest.mark.service +@pytest.mark.anyio +async def test_change_history_strict_match_under_work(monkeypatch): + import qdrant_client + + pts = [ + FakePoint({ + "metadata": { + "path": "/work/a.py", + "file_hash": "h1", + "last_modified_at": 100, + "ingested_at": 90, + "churn_count": 2, + } + }), + FakePoint({ + "metadata": { + "path": "/work/a.py", + "file_hash": "h2", + "last_modified_at": 120, + "ingested_at": 110, + "churn_count": 3, + } + }), + FakePoint({ + "metadata": { + "path": "/work/a.py", + "file_hash": "h1", + "last_modified_at": 130, + "ingested_at": 115, + "churn_count": 5, + } + }), + ] + + monkeypatch.setattr(qdrant_client, "QdrantClient", lambda *a, **k: FakeQdrant([pts])) + + res = await srv.change_history_for_path(path="/work/a.py", max_points=100) + assert res.get("ok") is True + summary = res.get("summary") or {} + assert summary.get("path") == "/work/a.py" + assert summary.get("points_scanned") == 3 + assert summary.get("distinct_hashes") == 2 # h1,h2 + assert summary.get("last_modified_min") == 100 + assert summary.get("last_modified_max") == 130 + assert summary.get("ingested_min") == 90 + assert summary.get("ingested_max") == 115 + assert summary.get("churn_count_max") == 5 + diff --git a/tests/test_context_answer_path_mention.py b/tests/test_context_answer_path_mention.py new file mode 100644 index 00000000..ec5df18e --- /dev/null +++ b/tests/test_context_answer_path_mention.py @@ -0,0 +1,40 @@ +import importlib +import pytest + +srv = importlib.import_module("scripts.mcp_indexer_server") + + +@pytest.mark.service +def test_context_answer_path_mention_fallback(monkeypatch): + # Force retrieval to return nothing so path-mention fallback engages + import scripts.hybrid_search as hs + monkeypatch.setattr(hs, "run_hybrid_search", lambda **k: []) + + import scripts.refrag_llamacpp as ref + + class FakeLlama: + def __init__(self, *a, **k): + pass + + def generate_with_soft_embeddings(self, prompt: str, max_tokens: int = 64, **kw): + # Should still include Sources and [1] with the mentioned file + assert "Sources:" in prompt + assert "[1]" in prompt + return "ok [1]" + + monkeypatch.setattr(ref, "LlamaCppRefragClient", FakeLlama) + monkeypatch.setattr(ref, "is_decoder_enabled", lambda: True) + + # Mention an actual file in this repo so fallback can find it + q = "explain something in scripts/hybrid_search.py" + out = srv.asyncio.get_event_loop().run_until_complete( + srv.context_answer(query=q, limit=3, per_path=2) + ) + assert isinstance(out, dict) + cits = out.get("citations") or [] + assert len(cits) >= 1 + # Either path or rel_path should indicate the file + p = cits[0].get("path") or "" + rp = cits[0].get("rel_path") or "" + assert p.endswith("scripts/hybrid_search.py") or rp.endswith("scripts/hybrid_search.py") + diff --git a/tests/test_reranker_verification.py b/tests/test_reranker_verification.py new file mode 100644 index 00000000..b7c24123 --- /dev/null +++ b/tests/test_reranker_verification.py @@ -0,0 +1,104 @@ +import importlib +import os +import sys +import types +import pytest + +# Provide a minimal stub for mcp.server.fastmcp.FastMCP so importing the server doesn't exit +mcp_pkg = types.ModuleType("mcp") +server_pkg = types.ModuleType("mcp.server") +fastmcp_pkg = types.ModuleType("mcp.server.fastmcp") + +class _FastMCP: + def __init__(self, *args, **kwargs): + pass + def tool(self, *args, **kwargs): + def _decorator(fn): + return fn + return _decorator + +setattr(fastmcp_pkg, "FastMCP", _FastMCP) +sys.modules.setdefault("mcp", mcp_pkg) +sys.modules.setdefault("mcp.server", server_pkg) +sys.modules.setdefault("mcp.server.fastmcp", fastmcp_pkg) + +srv = importlib.import_module("scripts.mcp_indexer_server") + + +@pytest.mark.service +@pytest.mark.anyio +async def test_rerank_inproc_changes_order(monkeypatch): + # Force in-process hybrid + in-process rerank paths + monkeypatch.setenv("HYBRID_IN_PROCESS", "1") + monkeypatch.setenv("RERANK_IN_PROCESS", "1") + + # Baseline hybrid results (JSON structured items); A before B + def fake_run_hybrid_search(**kwargs): + return [ + { + "score": 0.6, + "path": "/work/a.py", + "symbol": "", + "start_line": 1, + "end_line": 3, + }, + { + "score": 0.5, + "path": "/work/b.py", + "symbol": "", + "start_line": 10, + "end_line": 12, + }, + ] + + # Reranker returns higher score for B than A to force reordering + def fake_rerank_local(pairs): + # pairs order corresponds to hybrid results; return scores [A,B] -> [0.10, 0.90] + return [0.10, 0.90] + + # Patch hybrid and rerank + monkeypatch.setenv("EMBEDDING_MODEL", "BAAI/bge-base-en-v1.5") + monkeypatch.setattr( + importlib.import_module("scripts.hybrid_search"), "run_hybrid_search", fake_run_hybrid_search + ) + monkeypatch.setattr(importlib.import_module("scripts.rerank_local"), "rerank_local", fake_rerank_local) + + # Baseline (rerank disabled) preserves hybrid order A then B + base = await srv.repo_search(query="q", limit=2, per_path=2, rerank_enabled=False, compact=True) + assert [r["path"] for r in base["results"]] == ["/work/a.py", "/work/b.py"] + + # With rerank enabled, order should flip to B then A; counters should show inproc_hybrid + rr = await srv.repo_search(query="q", limit=2, per_path=2, rerank_enabled=True, compact=True) + assert rr.get("used_rerank") is True + assert rr.get("rerank_counters", {}).get("inproc_hybrid", 0) >= 1 + assert [r["path"] for r in rr["results"]] == ["/work/b.py", "/work/a.py"] + + +@pytest.mark.service +@pytest.mark.anyio +async def test_rerank_subprocess_timeout_fallback(monkeypatch): + # Force hybrid via subprocess output (doesn't matter which) and disable inproc rerank + monkeypatch.setenv("HYBRID_IN_PROCESS", "1") + monkeypatch.setenv("RERANK_IN_PROCESS", "0") + + def fake_run_hybrid_search(**kwargs): + return [ + {"score": 0.6, "path": "/work/a.py", "symbol": "", "start_line": 1, "end_line": 3}, + {"score": 0.5, "path": "/work/b.py", "symbol": "", "start_line": 10, "end_line": 12}, + ] + + async def fake_run_async(cmd, env=None, timeout=None): + # Simulate subprocess reranker timing out + return {"ok": False, "code": -1, "stdout": "", "stderr": f"Command timed out after {timeout}s"} + + monkeypatch.setattr( + importlib.import_module("scripts.hybrid_search"), "run_hybrid_search", fake_run_hybrid_search + ) + monkeypatch.setattr(srv, "_run_async", fake_run_async) + + rr = await srv.repo_search(query="q", limit=2, per_path=2, rerank_enabled=True, compact=True) + # Fallback should keep original order from hybrid; timeout counter incremented + assert rr.get("used_rerank") is False + assert rr.get("rerank_counters", {}).get("timeout", 0) >= 1 + assert [r["path"] for r in rr["results"]] == ["/work/a.py", "/work/b.py"] + diff --git a/tests/test_server_helpers.py b/tests/test_server_helpers.py index 464c9bd6..1c50740c 100644 --- a/tests/test_server_helpers.py +++ b/tests/test_server_helpers.py @@ -63,7 +63,7 @@ def test_repo_search_arg_normalization(monkeypatch, tmp_path): kind=None, symbol=None, ext=None, - not_filter=None, + not_=None, # Fixed: was not_filter case=None, path_regex=None, path_glob=None, diff --git a/tests/test_service_context_search.py b/tests/test_service_context_search.py index c2f50c9f..1a5e5c86 100644 --- a/tests/test_service_context_search.py +++ b/tests/test_service_context_search.py @@ -2,6 +2,14 @@ import json import pytest +# Import fastmcp BEFORE scripts.mcp_indexer_server to avoid import conflicts +# (scripts.mcp_indexer_server imports from mcp.server.fastmcp which can cause +# the mcp module to be in a partially initialized state) +try: + import fastmcp +except ImportError: + fastmcp = None # Will be handled in tests that need it + srv = importlib.import_module("scripts.mcp_indexer_server") @@ -118,7 +126,9 @@ async def list_tools(self): async def call_tool(self, *a, **k): return Resp() - import fastmcp + # fastmcp is already imported at module level + if fastmcp is None: + pytest.skip("fastmcp not available") monkeypatch.setattr(fastmcp, "Client", lambda *a, **k: FakeClient()) @@ -182,7 +192,9 @@ async def list_tools(self): async def call_tool(self, *a, **k): return Resp() - import fastmcp + # fastmcp is already imported at module level + if fastmcp is None: + pytest.skip("fastmcp not available") monkeypatch.setattr(fastmcp, "Client", lambda *a, **k: FakeClient())