diff --git a/.env.example b/.env.example index 7e23581d..562d84b6 100644 --- a/.env.example +++ b/.env.example @@ -149,6 +149,11 @@ SEMANTIC_EXPANSION_CACHE_TTL=3600 # HYBRID_RECENCY_WEIGHT=0.1 # RERANK_EXPAND=1 +# Elbow detection filter: adaptive threshold based on score distribution (Kneedle algorithm) +# Filters out low-relevance results by detecting the "elbow" point in the score curve +# Improves precision by only returning results above the natural relevance drop-off +# HYBRID_ELBOW_FILTER=0 + # Caching (embeddings and search results) # MAX_EMBED_CACHE=16384 # HYBRID_RESULTS_CACHE=128 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 47b85ae8..a3b785ae 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -65,14 +65,18 @@ jobs: python -c "from fastembed import TextEmbedding; m = TextEmbedding(model_name='BAAI/bge-base-en-v1.5'); list(m.embed(['test']))" - name: Run tests - run: pytest -q - + run: pytest -q --junitxml=test-results.xml + - name: Upload test results uses: actions/upload-artifact@v4 if: always() with: name: test-results - path: | - .pytest_cache/ - test-results.xml + path: test-results.xml retention-days: 7 + + - name: Test Summary + uses: test-summary/action@v2 + if: always() + with: + paths: test-results.xml diff --git a/ctx-mcp-bridge/package.json b/ctx-mcp-bridge/package.json index 04504c10..6c4e93cd 100644 --- a/ctx-mcp-bridge/package.json +++ b/ctx-mcp-bridge/package.json @@ -1,6 +1,6 @@ { "name": "@context-engine-bridge/context-engine-mcp-bridge", - "version": "0.0.15", + "version": "0.0.16", "description": "Context Engine MCP bridge (http/stdio proxy combining indexer + memory servers)", "bin": { "ctxce": "bin/ctxce.js", @@ -8,7 +8,8 @@ }, "type": "module", "scripts": { - "start": "node bin/ctxce.js" + "start": "node bin/ctxce.js", + "postinstall": "node -e \"try{require('fs').chmodSync('bin/ctxce.js',0o755)}catch(e){}\"" }, "dependencies": { "@modelcontextprotocol/sdk": "^1.24.3", @@ -20,4 +21,4 @@ 
"engines": { "node": ">=18.0.0" } -} +} \ No newline at end of file diff --git a/deploy/kubernetes/configmap.yaml b/deploy/kubernetes/configmap.yaml index 26c9e637..caf3e4c3 100644 --- a/deploy/kubernetes/configmap.yaml +++ b/deploy/kubernetes/configmap.yaml @@ -151,3 +151,4 @@ data: USE_GPU_DECODER: '0' USE_TREE_SITTER: '1' WATCH_DEBOUNCE_SECS: '4' + PSEUDO_DEFER_TO_WORKER: '1' diff --git a/docker-compose.yml b/docker-compose.yml index ccb219b5..1075675b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -453,9 +453,12 @@ services: - LEX_SPARSE_NAME=${LEX_SPARSE_NAME:-} # Pattern vectors for structural code similarity - PATTERN_VECTORS=${PATTERN_VECTORS:-} - # Graph edges for symbol relationships - - INDEX_GRAPH_EDGES=${INDEX_GRAPH_EDGES:-1} + # Graph edges for symbol relationships (always on) - INDEX_GRAPH_EDGES_MODE=${INDEX_GRAPH_EDGES_MODE:-symbol} + # Defer pseudo-tag generation to watcher worker for faster initial indexing + - PSEUDO_DEFER_TO_WORKER=${PSEUDO_DEFER_TO_WORKER:-1} + # Parallel indexing - number of worker threads (default: 4, use -1 for CPU count) + - INDEX_WORKERS=${INDEX_WORKERS:-4} volumes: - workspace_pvc:/work:rw - codebase_pvc:/work/.codebase:rw @@ -514,12 +517,13 @@ services: - LEX_SPARSE_NAME=${LEX_SPARSE_NAME:-} # Pattern vectors for structural code similarity - PATTERN_VECTORS=${PATTERN_VECTORS:-} - # Graph edges for symbol relationships - - INDEX_GRAPH_EDGES=${INDEX_GRAPH_EDGES:-1} + # Graph edges for symbol relationships (always on - Qdrant flat graph) - INDEX_GRAPH_EDGES_MODE=${INDEX_GRAPH_EDGES_MODE:-symbol} - GRAPH_BACKFILL_ENABLED=${GRAPH_BACKFILL_ENABLED:-1} - # Neo4j graph backend (when set, edges go to Neo4j instead of Qdrant _graph collection) + # Neo4j graph backend (optional - takes precedence over Qdrant flat graph) - NEO4J_GRAPH=${NEO4J_GRAPH:-} + # Defer pseudo-tag generation - watcher runs backfill worker thread + - PSEUDO_DEFER_TO_WORKER=${PSEUDO_DEFER_TO_WORKER:-1} volumes: - workspace_pvc:/work:rw - 
codebase_pvc:/work/.codebase:rw diff --git a/docs/CONFIGURATION.md b/docs/CONFIGURATION.md index 0e62d3d6..cbe83710 100644 --- a/docs/CONFIGURATION.md +++ b/docs/CONFIGURATION.md @@ -377,12 +377,26 @@ REFRAG_RUNTIME=glm # or openai, minimax, llamacpp ### Pseudo Backfill Worker -Deferred pseudo/tag generation runs asynchronously after initial indexing. +Deferred pseudo/tag generation runs asynchronously after initial indexing. This significantly speeds up initial indexing by skipping LLM-based pseudo-tag generation during the indexer run, deferring it to a background worker thread in the watcher service. | Name | Description | Default | |------|-------------|---------| | PSEUDO_BACKFILL_ENABLED | Enable async pseudo/tag backfill worker | 0 (disabled) | -| PSEUDO_DEFER_TO_WORKER | Skip inline pseudo, defer to backfill worker | 0 (disabled) | +| PSEUDO_DEFER_TO_WORKER | Skip inline pseudo, defer to backfill worker | 1 (enabled) | +| GRAPH_BACKFILL_ENABLED | Enable graph edge backfill in watcher worker | 1 (enabled) | + +**How it works:** +1. When `PSEUDO_DEFER_TO_WORKER=1`, the indexer generates only base chunks (no pseudo-tags) +2. The watcher service starts a `_start_pseudo_backfill_worker` daemon thread +3. This thread periodically calls `pseudo_backfill_tick()` to enrich chunks with LLM-generated tags +4. If `GRAPH_BACKFILL_ENABLED=1`, it also calls `graph_backfill_tick()` to populate symbol graph edges + +**Benefits:** +- Initial indexing is 2-5x faster (no LLM calls blocking indexer) +- Background enrichment happens continuously without blocking searches +- Failed LLM calls don't break indexing; worker retries automatically + +**Recommended for production:** Enable both for fastest initial indexing with eventual enrichment. ### Adaptive Span Sizing @@ -523,6 +537,7 @@ Useful for Kubernetes deployments where a shared filesystem is not reliable. 
| CODEBASE_STATE_REDIS_LOCK_WAIT_MS | Redis lock wait in ms | 2000 | | CODEBASE_STATE_REDIS_SOCKET_TIMEOUT | Redis socket timeout in seconds | 2 | | CODEBASE_STATE_REDIS_CONNECT_TIMEOUT | Redis connect timeout in seconds | 2 | +| CODEBASE_STATE_REDIS_MAX_CONNECTIONS | Redis connection pool size limit | 10 | ### Semantic Expansion diff --git a/pyproject.toml b/pyproject.toml index 2b57fdf3..abef2155 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ dependencies = [ "rich>=13.0.0", "typer>=0.9.0", "requests>=2.28.0", + "xxhash>=3.0.0", ] [project.optional-dependencies] diff --git a/scripts/ast_analyzer.py b/scripts/ast_analyzer.py index 0b87e443..c1de92e7 100644 --- a/scripts/ast_analyzer.py +++ b/scripts/ast_analyzer.py @@ -26,6 +26,19 @@ logger = logging.getLogger("ast_analyzer") +# --------------------------------------------------------------------------- +# Language Mappings Integration +# --------------------------------------------------------------------------- +# Context-Engine's unified concept-based extraction supporting 32 languages. +# Uses declarative tree-sitter queries organized by semantic concept type: +# DEFINITION, BLOCK, COMMENT, IMPORT, STRUCTURE +_LANGUAGE_MAPPINGS_AVAILABLE = False +try: + from scripts.ingest.language_mappings import get_mapping, supported_languages as lm_supported_languages, ConceptType + _LANGUAGE_MAPPINGS_AVAILABLE = True +except ImportError: + pass + # Optional tree-sitter support - tree-sitter 0.25+ API _TS_LANGUAGES: Dict[str, Any] = {} _TS_AVAILABLE = False @@ -131,6 +144,27 @@ class CodeSymbol: parent: Optional[str] = None # Parent class/module complexity: int = 0 # Cyclomatic complexity estimate content_hash: Optional[str] = None + concept: Optional[str] = None # Universal concept type (definition, block, comment, etc.) + + +@dataclass +class ConceptUnit: + """A semantic code unit with universal concept classification. 
+ + + Context-Engine's 5 universal concepts for language-agnostic analysis: + - DEFINITION: functions, classes, types, constants + - BLOCK: control flow, scoped regions + - COMMENT: comments, docstrings + - IMPORT: import/include statements + - STRUCTURE: file-level organization + """ + concept: str # definition, block, comment, import, structure + name: str + content: str + start_line: int + end_line: int + kind: str = "" # More specific: function, class, if, for, etc. + metadata: Dict[str, Any] = field(default_factory=dict) @dataclass @@ -223,7 +257,13 @@ def analyze_file( logger.error(f"Failed to read {file_path}: {e}") return self._empty_analysis() - # Route to appropriate analyzer + # Use language mappings (32 languages, declarative queries) + if _LANGUAGE_MAPPINGS_AVAILABLE and self.use_tree_sitter: + result = self._analyze_with_mapping(content, file_path, language) + if result and (result.get("symbols") or result.get("imports") or result.get("calls")): + return result + + # Fallback to legacy per-language analyzers if language == "python": return self._analyze_python(content, file_path) elif language in ("javascript", "typescript") and self.use_tree_sitter: @@ -396,6 +436,511 @@ def extract_dependencies( "local": list(set(local)) } + # ---- Language Mappings Analysis (unified, concept-based) ---- + + def _analyze_with_mapping(self, content: str, file_path: str, language: str) -> Dict[str, Any]: + """Analyze code using language mappings (concept-based extraction). + + This uses the declarative tree-sitter queries from language_mappings + to extract symbols, imports, and calls. Supports 32 languages. 
+ """ + if not _LANGUAGE_MAPPINGS_AVAILABLE: + return self._empty_analysis() + + try: + mapping = get_mapping(self._normalize_lang(language)) + except (TypeError, Exception) as e: + logger.debug(f"Mapping instantiation failed for {language}: {e}") + return self._empty_analysis() + + if not mapping: + return self._empty_analysis() + + # Get parser for this language + parser = self._get_ts_parser(language) + if not parser: + return self._empty_analysis() + + try: + tree = parser.parse(content.encode("utf-8")) + root = tree.root_node + except Exception as e: + logger.debug(f"Tree-sitter parse failed for {language}: {e}") + return self._empty_analysis() + + content_bytes = content.encode("utf-8") + symbols: List[CodeSymbol] = [] + imports: List[ImportReference] = [] + calls: List[CallReference] = [] + + # Get tree-sitter language object for queries + ts_lang = _TS_LANGUAGES.get(language) or _TS_LANGUAGES.get(self._normalize_lang(language)) + if not ts_lang: + return self._empty_analysis() + + try: + from tree_sitter import Query, QueryCursor + except ImportError: + return self._empty_analysis() + + # Extract DEFINITION concepts -> symbols + def_query_str = mapping.get_query_for_concept(ConceptType.DEFINITION) + if def_query_str: + try: + query = Query(ts_lang, def_query_str) + cursor = QueryCursor(query) + seen_ranges: Set[Tuple[int, int]] = set() + + for match in cursor.matches(root): + _, captures_dict = match + main_node = None + name_node = None + + for capture_name, nodes in captures_dict.items(): + if not nodes: + continue + node = nodes[0] + if capture_name in ("definition", "function_def", "class_def", + "method_def", "type_def", "const_def"): + main_node = node + elif capture_name in ("name", "function_name", "class_name", + "method_name", "type_name", "const_name"): + name_node = node + elif main_node is None: + main_node = node + + if main_node is None: + continue + + range_key = (main_node.start_byte, main_node.end_byte) + if range_key in seen_ranges: + 
continue + seen_ranges.add(range_key) + + # Extract name + if name_node: + name = content_bytes[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") + else: + name = self._extract_name_from_ts_node(main_node, content_bytes) + + # Infer kind from node type + kind = self._node_type_to_kind(main_node.type) + + # Extract docstring if available + docstring = self._extract_ts_docstring(main_node, content_bytes) + + # Extract signature + signature = self._extract_ts_signature(main_node, content_bytes, name, kind) + + # Extract decorators (for Python, etc.) + decorators = self._extract_ts_decorators(main_node, content_bytes) + + # Determine parent + parent = self._find_ts_parent_name(main_node, content_bytes) + + symbols.append(CodeSymbol( + name=name, + kind=kind, + start_line=main_node.start_point[0] + 1, + end_line=main_node.end_point[0] + 1, + path=f"{parent}.{name}" if parent else name, + docstring=docstring, + signature=signature, + decorators=decorators, + parent=parent, + )) + except Exception as e: + logger.debug(f"DEFINITION query failed for {language}: {e}") + + # Extract IMPORT concepts -> imports + import_query_str = mapping.get_query_for_concept(ConceptType.IMPORT) + if import_query_str: + try: + query = Query(ts_lang, import_query_str) + cursor = QueryCursor(query) + seen_ranges: Set[Tuple[int, int]] = set() + + for match in cursor.matches(root): + _, captures_dict = match + main_node = None + path_node = None + + for capture_name, nodes in captures_dict.items(): + if not nodes: + continue + node = nodes[0] + # Look for import path specifically + if capture_name in ("import_path", "path", "module", "source"): + path_node = node + # Look for import statement container + elif capture_name in ("import", "import_from", "import_statement", + "import_spec", "import_declaration", + "include", "require", "use", "definition"): + if main_node is None or node.start_byte < main_node.start_byte: + main_node = node + + # Use path_node if available 
for cleaner import text + import_node = path_node or main_node + if import_node is None: + continue + + range_key = (import_node.start_byte, import_node.end_byte) + if range_key in seen_ranges: + continue + seen_ranges.add(range_key) + + import_text = content_bytes[import_node.start_byte:import_node.end_byte].decode("utf-8", errors="replace") + module, names, is_from = self._parse_import_text(import_text, language) + + # If path_node was used directly, the text might be just the path + if not module and path_node: + module = import_text.strip().strip('"\'') + + if module: + imports.append(ImportReference( + module=module, + names=names, + line=import_node.start_point[0] + 1, + is_from=is_from, + )) + except Exception as e: + logger.debug(f"IMPORT query failed for {language}: {e}") + + # Extract calls by walking the tree for call expressions + calls = self._extract_calls_from_tree(root, content_bytes, symbols, language) + + # Extract all concepts for comprehensive analysis + concepts: List[ConceptUnit] = [] + for concept_type in ConceptType: + query_str = mapping.get_query_for_concept(concept_type) + if not query_str: + continue + try: + query = Query(ts_lang, query_str) + cursor = QueryCursor(query) + seen: Set[Tuple[int, int]] = set() + + for match in cursor.matches(root): + _, captures_dict = match + main_node = None + name_node = None + + for cname, nodes in captures_dict.items(): + if not nodes: + continue + node = nodes[0] + if cname in ("definition", "block", "import", "comment", "structure"): + main_node = node + elif cname == "name" or cname.endswith("_name"): + name_node = node + elif main_node is None: + main_node = node + + if main_node is None: + continue + + rkey = (main_node.start_byte, main_node.end_byte) + if rkey in seen: + continue + seen.add(rkey) + + if name_node: + name = content_bytes[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") + else: + name = self._extract_name_from_ts_node(main_node, content_bytes) + + 
unit_content = content_bytes[main_node.start_byte:main_node.end_byte].decode("utf-8", errors="replace") + + concepts.append(ConceptUnit( + concept=concept_type.value, + name=name, + content=unit_content, + start_line=main_node.start_point[0] + 1, + end_line=main_node.end_point[0] + 1, + kind=self._node_type_to_kind(main_node.type), + )) + except Exception as e: + logger.debug(f"{concept_type.value} query failed for {language}: {e}") + + return { + "symbols": symbols, + "imports": imports, + "calls": calls, + "concepts": concepts, # All semantic units by concept type + "language": language, + } + + def _normalize_lang(self, language: str) -> str: + """Normalize language name to tree-sitter key.""" + lang = language.lower().strip() + aliases = { + "js": "javascript", "jsx": "javascript", + "ts": "typescript", "tsx": "typescript", + "c++": "cpp", "cxx": "cpp", + "c#": "csharp", "cs": "csharp", + "shell": "bash", "sh": "bash", + } + return aliases.get(lang, lang) + + def _extract_name_from_ts_node(self, node, content_bytes: bytes) -> str: + """Extract name from tree-sitter node.""" + # Try field 'name' first + if hasattr(node, 'child_by_field_name'): + name_node = node.child_by_field_name('name') + if name_node: + return content_bytes[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") + + # Look for identifier child + for i in range(node.child_count): + child = node.child(i) + if child and child.type in ("identifier", "name", "type_identifier"): + return content_bytes[child.start_byte:child.end_byte].decode("utf-8", errors="replace") + + return f"anonymous_{node.start_point[0] + 1}" + + def _node_type_to_kind(self, node_type: str) -> str: + """Map tree-sitter node type to symbol kind.""" + mapping = { + # Functions + "function_definition": "function", + "async_function_definition": "function", + "function_declaration": "function", + "arrow_function": "function", + "function_item": "function", + "generator_function_declaration": "function", + # 
Methods + "method_definition": "method", + "method_declaration": "method", + # Classes + "class_definition": "class", + "class_declaration": "class", + "class_specifier": "class", + # Structs (Go, Rust, C/C++) + "struct_item": "struct", + "struct_specifier": "struct", + "type_declaration": "struct", # Go uses this for struct/interface + "type_spec": "struct", + # Interfaces + "interface_declaration": "interface", + "interface_type": "interface", + # Types + "type_alias_declaration": "type", + "type_item": "type", + # Enums + "enum_declaration": "enum", + "enum_item": "enum", + # Rust-specific + "impl_item": "impl", + "trait_item": "trait", + "mod_item": "module", + # Constants/Variables + "const_item": "constant", + "const_declaration": "constant", + "variable_declaration": "variable", + "lexical_declaration": "variable", + # Imports + "import_statement": "import", + "import_declaration": "import", + "import_spec": "import", + # Comments + "comment": "comment", + "block_comment": "comment", + "line_comment": "comment", + # Control flow (for BLOCK concepts) + "if_statement": "if", + "for_statement": "for", + "while_statement": "while", + "try_statement": "try", + "switch_statement": "switch", + "match_expression": "match", + } + return mapping.get(node_type, "symbol") + + def _extract_ts_docstring(self, node, content_bytes: bytes) -> Optional[str]: + """Extract docstring from node body.""" + body = node.child_by_field_name('body') if hasattr(node, 'child_by_field_name') else None + if not body: + return None + + for i in range(min(2, body.child_count)): + child = body.child(i) + if child and child.type == "expression_statement": + for j in range(child.child_count): + expr = child.child(j) + if expr and expr.type == "string": + text = content_bytes[expr.start_byte:expr.end_byte].decode("utf-8", errors="replace") + # Strip quotes + if text.startswith('"""') or text.startswith("'''"): + return text[3:-3].strip() + elif text.startswith('"') or text.startswith("'"): + 
return text[1:-1].strip() + return None + + def _extract_ts_signature(self, node, content_bytes: bytes, name: str, kind: str) -> str: + """Build signature from node.""" + if kind in ("function", "method"): + params_node = node.child_by_field_name('parameters') if hasattr(node, 'child_by_field_name') else None + if params_node: + params_text = content_bytes[params_node.start_byte:params_node.end_byte].decode("utf-8", errors="replace") + return f"def {name}{params_text}" + return f"def {name}()" + elif kind == "class": + return f"class {name}" + return name + + def _extract_ts_decorators(self, node, content_bytes: bytes) -> List[str]: + """Extract decorators from preceding siblings.""" + decorators = [] + prev = node.prev_sibling + while prev and prev.type == "decorator": + dec_text = content_bytes[prev.start_byte:prev.end_byte].decode("utf-8", errors="replace") + dec_name = dec_text.lstrip("@").split("(")[0] + decorators.insert(0, dec_name) + prev = prev.prev_sibling + return decorators + + def _find_ts_parent_name(self, node, content_bytes: bytes) -> Optional[str]: + """Find parent class/module name.""" + parent = node.parent + while parent: + if parent.type in ("class_definition", "class_declaration", "class_specifier", + "impl_item", "module"): + name_node = parent.child_by_field_name('name') if hasattr(parent, 'child_by_field_name') else None + if name_node: + return content_bytes[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") + parent = parent.parent + return None + + def _parse_import_text(self, text: str, language: str) -> Tuple[str, List[str], bool]: + """Parse import statement text to extract module and names.""" + text = text.strip() + + # Python: from X import Y or import X + if language == "python": + if text.startswith("from "): + match = re.match(r"from\s+([\w.]+)\s+import\s+(.+)", text) + if match: + module = match.group(1) + names_str = match.group(2) + names = [n.strip().split(" as ")[0] for n in names_str.split(",")] + 
return module, names, True + elif text.startswith("import "): + modules_part = text[7:].strip() + modules = [m.strip().split(" as ")[0].strip() for m in modules_part.split(",")] + if modules: + return modules[0], modules[1:] if len(modules) > 1 else [], False + + # JavaScript/TypeScript: import X from 'Y' or require('Y') + elif language in ("javascript", "typescript", "jsx", "tsx"): + if "from" in text: + match = re.search(r"from\s+['\"]([^'\"]+)['\"]", text) + if match: + return match.group(1), [], True + elif "require" in text: + match = re.search(r"require\s*\(\s*['\"]([^'\"]+)['\"]", text) + if match: + return match.group(1), [], False + + # Go: import "path" + elif language == "go": + match = re.search(r'"([^"]+)"', text) + if match: + return match.group(1), [], False + + # Rust: use path::to::module + elif language == "rust": + match = re.match(r"use\s+([\w:]+)", text) + if match: + return match.group(1), [], False + + # Java/Kotlin: import package.Class; + elif language in ("java", "kotlin"): + match = re.match(r"import\s+([\w.]+);?", text) + if match: + return match.group(1), [], False + + # C/C++: #include
or #include "header" + elif language in ("c", "cpp"): + match = re.search(r'#include\s*[<"]([^>"]+)[>"]', text) + if match: + return match.group(1), [], False + + # C#: using Namespace; + elif language == "csharp": + match = re.match(r"using\s+([\w.]+);?", text) + if match: + return match.group(1), [], False + + # Generic: try to find quoted string + match = re.search(r"['\"]([^'\"]+)['\"]", text) + if match: + return match.group(1), [], False + + return "", [], False + + def _extract_calls_from_tree(self, root, content_bytes: bytes, symbols: List[CodeSymbol], language: str) -> List[CallReference]: + """Walk tree to extract function calls.""" + calls: List[CallReference] = [] + # Only include functions/methods/classes as valid callers - NOT assignments/constants + # This prevents calls like `result = foo()` from being attributed to `result` instead of + # the enclosing function + valid_caller_kinds = {"function", "method", "class", "async_function", "module"} + symbol_ranges = [ + (s.start_line, s.end_line, s.path or s.name) + for s in symbols + if s.kind in valid_caller_kinds + ] + + def find_enclosing_symbol(line: int) -> str: + best_match = "" + best_span = float("inf") + for start, end, path in symbol_ranges: + if start <= line <= end: + span = end - start + if span < best_span: + best_span = span + best_match = path + return best_match + + def walk(node): + node_type = node.type + + # Call expressions + if node_type in ("call", "call_expression", "function_call", "method_call"): + func_node = node.child_by_field_name('function') if hasattr(node, 'child_by_field_name') else None + if not func_node: + # Try first child + for i in range(node.child_count): + child = node.child(i) + if child and child.type in ("identifier", "member_expression", "attribute"): + func_node = child + break + + if func_node: + callee = content_bytes[func_node.start_byte:func_node.end_byte].decode("utf-8", errors="replace") + # Clean up callee (get last part of attribute access) + if "." 
in callee: + callee = callee.split(".")[-1] + + line = node.start_point[0] + 1 + caller = find_enclosing_symbol(line) + + calls.append(CallReference( + caller=caller, + callee=callee, + line=line, + context="call", + )) + + # Recurse + for i in range(node.child_count): + child = node.child(i) + if child: + walk(child) + + walk(root) + return calls + # ---- Python-specific analysis (using ast module) ---- def _analyze_python(self, content: str, file_path: str) -> Dict[str, Any]: diff --git a/scripts/ctx.py b/scripts/ctx.py index 07ae7496..f6a9c6a4 100755 --- a/scripts/ctx.py +++ b/scripts/ctx.py @@ -1070,7 +1070,34 @@ def fetch_context(query: str, **filters) -> Tuple[str, str]: sys.stderr.flush() return "", "Context retrieval returned no data." - hits = data.get("results") or [] + def _extract_hits(payload: dict, *, label: str) -> list: + """Extract hits from MCP response, handling TOON-encoded strings. + + Prefers results_json when available, falls back to results, and + attempts TOON decode if the value is a string. + """ + hits_val = payload.get("results_json") or payload.get("results") or [] + if isinstance(hits_val, str): + raw_hits = hits_val + # Prefer preserved structured results_json when present + if isinstance(payload.get("results_json"), list): + hits_val = payload["results_json"] + else: + try: + from toon import decode as toon_decode # type: ignore[import-untyped] + decoded = toon_decode(raw_hits) + decoded_hits = decoded.get("results") or [] + hits_val = decoded_hits if isinstance(decoded_hits, list) else [] + except Exception: + sys.stderr.write( + f"[DEBUG] {label} returned TOON-formatted string results but could not decode; " + "treating as no results. 
Hint: install 'toon' or set TOON_ENABLED=0.\n" + ) + sys.stderr.flush() + hits_val = [] + return hits_val if isinstance(hits_val, list) else [] + + hits = _extract_hits(data, label="repo_search") relevance = _estimate_query_result_relevance(query, hits) sys.stderr.write(f"[DEBUG] repo_search returned {len(hits)} hits (relevance={relevance:.3f})\n") sys.stderr.flush() @@ -1120,7 +1147,7 @@ def fetch_context(query: str, **filters) -> Tuple[str, str]: if "error" not in memory_result: memory_data = parse_mcp_response(memory_result) if memory_data: - memory_hits = memory_data.get("results") or [] + memory_hits = _extract_hits(memory_data, label="context_search") if memory_hits: return format_search_results(memory_hits, include_snippets=with_snippets), "Using memories and design docs" return "", "No relevant context found for the prompt." diff --git a/scripts/ctx_cli/commands/init.py b/scripts/ctx_cli/commands/init.py index 473f2c2a..c3d7d7fa 100644 --- a/scripts/ctx_cli/commands/init.py +++ b/scripts/ctx_cli/commands/init.py @@ -651,9 +651,6 @@ def configure_env_file(skip_if_exists: bool = False) -> bool: NEO4J_URI=bolt://neo4j:7687 NEO4J_USER=neo4j NEO4J_PASSWORD=contextengine - -# Symbol graph -SYMBOL_GRAPH_ENABLED=1 """ # Add API keys if configured diff --git a/scripts/ctx_cli/commands/reset.py b/scripts/ctx_cli/commands/reset.py index 9cbdde88..3e8642cf 100644 --- a/scripts/ctx_cli/commands/reset.py +++ b/scripts/ctx_cli/commands/reset.py @@ -344,13 +344,13 @@ def reset( # Build env vars for indexer indexer_env = {} - for var in ["INDEX_MICRO_CHUNKS", "MAX_MICRO_CHUNKS_PER_FILE", "TOKENIZER_PATH", "TOKENIZER_URL"]: + for var in ["INDEX_MICRO_CHUNKS", "MAX_MICRO_CHUNKS_PER_FILE", "TOKENIZER_PATH", "TOKENIZER_URL", "INDEX_WORKERS"]: if var in os.environ: indexer_env[var] = os.environ[var] - # Defer pseudo-describe to backfill worker for much faster initial indexing - # The watch_index worker will backfill pseudo/tags after indexing completes 
indexer_env["PSEUDO_DEFER_TO_WORKER"] = "1" + if "INDEX_WORKERS" not in indexer_env: + indexer_env["INDEX_WORKERS"] = "4" # Run indexer detached (-d) so CLI doesn't block # Use --rm to auto-remove container on exit; first remove any stale container with same name diff --git a/scripts/ctx_cli/commands/search.py b/scripts/ctx_cli/commands/search.py index 38eafd04..0f63975a 100755 --- a/scripts/ctx_cli/commands/search.py +++ b/scripts/ctx_cli/commands/search.py @@ -299,9 +299,27 @@ def search_command( print(json.dumps(data, indent=2)) return 0 - # Extract results + # Extract results - handle TOON format if present results = data.get("results", []) - total = data.get("total", len(results)) + + # If results is a TOON string, try to decode it or use results_json fallback + if isinstance(results, str): + # First try results_json (preserved by server for internal callers) + if "results_json" in data and isinstance(data["results_json"], list): + results = data["results_json"] + else: + # Try to decode TOON string + try: + from toon import decode as toon_decode + decoded = toon_decode(results) + results = decoded.get("results", []) + except Exception: + # If TOON decode fails, return error + print("Error: Received TOON-formatted results but could not decode", file=sys.stderr) + print("Hint: Install toon package or set TOON_ENABLED=0", file=sys.stderr) + return 1 + + total = data.get("total", len(results) if isinstance(results, list) else 0) # Handle no results if not results: diff --git a/scripts/embedding_provider.py b/scripts/embedding_provider.py new file mode 100644 index 00000000..6a59f987 --- /dev/null +++ b/scripts/embedding_provider.py @@ -0,0 +1,303 @@ +"""Protocol-based embedding provider interface. + +Defines the abstract interface for embedding implementations, enabling +pluggable backends (FastEmbed, OpenAI, local models, etc.) while maintaining +consistent behavior and type safety. 
+ +Usage: + class FastEmbedProvider: + '''Implements EmbeddingProvider protocol.''' + + @property + def name(self) -> str: + return "fastembed" + + async def embed(self, texts: list[str]) -> list[list[float]]: + return list(self.model.embed(texts)) + + # Manager for multiple providers + manager = EmbeddingManager() + manager.register_provider(FastEmbedProvider(), set_default=True) + + embeddings = await manager.embed(["hello", "world"]) +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import ( + Any, + AsyncIterator, + Dict, + List, + Optional, + Protocol, + runtime_checkable, +) + + +@dataclass +class RerankResult: + """Result from reranking operation.""" + index: int + score: float + text: Optional[str] = None + + +@dataclass +class EmbeddingConfig: + """Configuration for embedding providers.""" + provider: str + model: str + dims: int + distance: str = "cosine" + batch_size: int = 100 + max_tokens: Optional[int] = None + api_key: Optional[str] = None + base_url: Optional[str] = None + timeout: int = 30 + retry_attempts: int = 3 + retry_delay: float = 1.0 + + +@runtime_checkable +class EmbeddingProvider(Protocol): + """Abstract protocol for embedding providers. + + All embedding implementations must follow this interface. + Enables pluggable backends (OpenAI, local models, etc.) + """ + + @property + def name(self) -> str: + """Provider name (e.g., 'fastembed', 'openai').""" + ... + + @property + def model(self) -> str: + """Model name (e.g., 'BAAI/bge-base-en-v1.5').""" + ... + + @property + def dims(self) -> int: + """Embedding dimensions.""" + ... + + @property + def distance(self) -> str: + """Distance metric ('cosine' | 'l2' | 'ip').""" + ... + + @property + def batch_size(self) -> int: + """Maximum batch size for embedding requests.""" + ... + + async def embed(self, texts: List[str]) -> List[List[float]]: + """Generate embeddings for a list of texts.""" + ... 
+ + async def embed_single(self, text: str) -> List[float]: + """Generate embedding for a single text.""" + ... + + async def embed_batch( + self, texts: List[str], batch_size: Optional[int] = None + ) -> List[List[float]]: + """Generate embeddings in batches for optimal performance.""" + ... + + def is_available(self) -> bool: + """Check if the provider is available and properly configured.""" + ... + + def get_optimal_batch_size(self) -> int: + """Get optimal batch size for this provider.""" + ... + + def get_max_tokens_per_batch(self) -> int: + """Get maximum tokens per batch for this provider.""" + ... + + def get_recommended_concurrency(self) -> int: + """Get recommended concurrent batch count based on provider's rate limits.""" + ... + + def supports_reranking(self) -> bool: + """Return True if this provider supports reranking.""" + ... + + +class BaseEmbeddingProvider(ABC): + """Base class with default implementations for common operations.""" + + def __init__(self, config: EmbeddingConfig): + self._config = config + self._total_tokens = 0 + self._total_requests = 0 + + @property + def config(self) -> EmbeddingConfig: + return self._config + + @property + def name(self) -> str: + return self._config.provider + + @property + def model(self) -> str: + return self._config.model + + @property + def dims(self) -> int: + return self._config.dims + + @property + def distance(self) -> str: + return self._config.distance + + @property + def batch_size(self) -> int: + return self._config.batch_size + + @abstractmethod + async def embed(self, texts: List[str]) -> List[List[float]]: + """Generate embeddings for a list of texts.""" + ... 
+ + async def embed_single(self, text: str) -> List[float]: + """Generate embedding for a single text.""" + results = await self.embed([text]) + return results[0] + + async def embed_batch( + self, texts: List[str], batch_size: Optional[int] = None + ) -> List[List[float]]: + """Generate embeddings in batches.""" + effective_batch_size = batch_size or self.batch_size + all_embeddings: List[List[float]] = [] + + for i in range(0, len(texts), effective_batch_size): + batch = texts[i : i + effective_batch_size] + embeddings = await self.embed(batch) + all_embeddings.extend(embeddings) + + return all_embeddings + + async def embed_streaming(self, texts: List[str]) -> AsyncIterator[List[float]]: + """Generate embeddings with streaming results.""" + for text in texts: + embedding = await self.embed_single(text) + yield embedding + + def is_available(self) -> bool: + """Check if provider is available.""" + return True + + def get_optimal_batch_size(self) -> int: + return self.batch_size + + def get_max_tokens_per_batch(self) -> int: + return self._config.max_tokens or 8192 + + def get_max_documents_per_batch(self) -> int: + return self.batch_size + + def get_recommended_concurrency(self) -> int: + return 4 + + def get_max_rerank_batch_size(self) -> int: + return 64 + + def supports_reranking(self) -> bool: + return False + + async def rerank( + self, query: str, documents: List[str], top_k: Optional[int] = None + ) -> List[RerankResult]: + raise NotImplementedError("Reranking not supported by this provider") + + def estimate_tokens(self, text: str) -> int: + """Rough token estimation (chars / 4).""" + return len(text) // 4 + + def get_usage_stats(self) -> Dict[str, Any]: + return { + "total_tokens": self._total_tokens, + "total_requests": self._total_requests, + } + + def reset_usage_stats(self) -> None: + self._total_tokens = 0 + self._total_requests = 0 + + +@dataclass +class EmbeddingManager: + """Manager for multiple embedding providers. 
+ + Centralizes provider registration and selection. + """ + + _providers: Dict[str, EmbeddingProvider] = field(default_factory=dict) + _default_provider: Optional[str] = None + + def register_provider( + self, provider: EmbeddingProvider, set_default: bool = False + ) -> None: + """Register an embedding provider.""" + self._providers[provider.name] = provider + if set_default or self._default_provider is None: + self._default_provider = provider.name + + def get_provider(self, name: Optional[str] = None) -> EmbeddingProvider: + """Get a provider by name or the default provider.""" + provider_name = name or self._default_provider + if provider_name is None: + raise ValueError("No default provider set and no name provided") + if provider_name not in self._providers: + raise ValueError(f"Provider '{provider_name}' not registered") + return self._providers[provider_name] + + def list_providers(self) -> List[str]: + """List registered provider names.""" + return list(self._providers.keys()) + + async def embed( + self, texts: List[str], provider_name: Optional[str] = None + ) -> List[List[float]]: + """Generate embeddings using specified or default provider.""" + provider = self.get_provider(provider_name) + return await provider.embed(texts) + + async def embed_batch( + self, + texts: List[str], + batch_size: Optional[int] = None, + provider_name: Optional[str] = None, + ) -> List[List[float]]: + """Generate embeddings in batches.""" + provider = self.get_provider(provider_name) + return await provider.embed_batch(texts, batch_size) + + +_default_manager: Optional[EmbeddingManager] = None + + +def get_embedding_manager() -> EmbeddingManager: + """Get or create the default embedding manager.""" + global _default_manager + if _default_manager is None: + _default_manager = EmbeddingManager() + return _default_manager + + +__all__ = [ + "RerankResult", + "EmbeddingConfig", + "EmbeddingProvider", + "BaseEmbeddingProvider", + "EmbeddingManager", + "get_embedding_manager", 
+] diff --git a/scripts/exceptions.py b/scripts/exceptions.py new file mode 100644 index 00000000..c48f603b --- /dev/null +++ b/scripts/exceptions.py @@ -0,0 +1,273 @@ +"""Structured exception hierarchy for Context-Engine. + +Provides a clear taxonomy of errors for better error handling, +logging, and debugging. All exceptions include context about +what operation failed and why. + +Usage: + try: + parse_file(path) + except ParsingError as e: + logger.error(f"Failed to parse {e.file_path}: {e}") + except ContextEngineError as e: + logger.error(f"Operation failed: {e}") +""" + +from __future__ import annotations + +from typing import Optional, Any, Dict + + +class ContextEngineError(Exception): + """Base exception for all Context-Engine errors.""" + + def __init__(self, message: str, context: Optional[Dict[str, Any]] = None): + super().__init__(message) + self.message = message + self.context = context or {} + + def __str__(self) -> str: + if self.context: + ctx_str = ", ".join(f"{k}={v}" for k, v in self.context.items()) + return f"{self.message} [{ctx_str}]" + return self.message + + +class ValidationError(ContextEngineError): + """Invalid input or configuration.""" + + def __init__(self, message: str, field: Optional[str] = None, value: Any = None): + context = {} + if field: + context["field"] = field + if value is not None: + context["value"] = repr(value)[:100] + super().__init__(message, context) + self.field = field + self.value = value + + +class ParsingError(ContextEngineError): + """Failed to parse a file or content.""" + + def __init__( + self, + message: str, + file_path: Optional[str] = None, + language: Optional[str] = None, + line: Optional[int] = None, + ): + context = {} + if file_path: + context["file"] = file_path + if language: + context["language"] = language + if line is not None: + context["line"] = line + super().__init__(message, context) + self.file_path = file_path + self.language = language + self.line = line + + +class 
ChunkingError(ContextEngineError): + """Failed to chunk content.""" + + def __init__( + self, + message: str, + file_path: Optional[str] = None, + chunk_index: Optional[int] = None, + ): + context = {} + if file_path: + context["file"] = file_path + if chunk_index is not None: + context["chunk_index"] = chunk_index + super().__init__(message, context) + self.file_path = file_path + self.chunk_index = chunk_index + + +class EmbeddingError(ContextEngineError): + """Failed to generate embeddings.""" + + def __init__( + self, + message: str, + provider: Optional[str] = None, + model: Optional[str] = None, + batch_size: Optional[int] = None, + ): + context = {} + if provider: + context["provider"] = provider + if model: + context["model"] = model + if batch_size is not None: + context["batch_size"] = batch_size + super().__init__(message, context) + self.provider = provider + self.model = model + self.batch_size = batch_size + + +class IndexingError(ContextEngineError): + """Failed to index content into vector database.""" + + def __init__( + self, + message: str, + collection: Optional[str] = None, + file_path: Optional[str] = None, + point_count: Optional[int] = None, + ): + context = {} + if collection: + context["collection"] = collection + if file_path: + context["file"] = file_path + if point_count is not None: + context["points"] = point_count + super().__init__(message, context) + self.collection = collection + self.file_path = file_path + self.point_count = point_count + + +class DatabaseError(ContextEngineError): + """Database operation failed.""" + + def __init__( + self, + message: str, + operation: Optional[str] = None, + collection: Optional[str] = None, + ): + context = {} + if operation: + context["operation"] = operation + if collection: + context["collection"] = collection + super().__init__(message, context) + self.operation = operation + self.collection = collection + + +class SearchError(ContextEngineError): + """Search operation failed.""" + + def 
__init__( + self, + message: str, + query: Optional[str] = None, + collection: Optional[str] = None, + ): + context = {} + if query: + context["query"] = query[:100] + if collection: + context["collection"] = collection + super().__init__(message, context) + self.query = query + self.collection = collection + + +class ConfigurationError(ContextEngineError): + """Invalid or missing configuration.""" + + def __init__(self, message: str, config_key: Optional[str] = None): + context = {} + if config_key: + context["key"] = config_key + super().__init__(message, context) + self.config_key = config_key + + +class ProviderError(ContextEngineError): + """External provider/service error.""" + + def __init__( + self, + message: str, + provider: Optional[str] = None, + status_code: Optional[int] = None, + ): + context = {} + if provider: + context["provider"] = provider + if status_code is not None: + context["status"] = status_code + super().__init__(message, context) + self.provider = provider + self.status_code = status_code + + +class CacheError(ContextEngineError): + """Cache operation failed.""" + + def __init__( + self, + message: str, + cache_type: Optional[str] = None, + key: Optional[str] = None, + ): + context = {} + if cache_type: + context["cache"] = cache_type + if key: + context["key"] = key[:100] + super().__init__(message, context) + self.cache_type = cache_type + self.key = key + + +class RateLimitError(ProviderError): + """Rate limit exceeded.""" + + def __init__( + self, + message: str, + provider: Optional[str] = None, + retry_after: Optional[float] = None, + ): + super().__init__(message, provider, status_code=429) + self.retry_after = retry_after + if retry_after is not None: + self.context["retry_after"] = retry_after + + +class TimeoutError(ContextEngineError): + """Operation timed out.""" + + def __init__( + self, + message: str, + operation: Optional[str] = None, + timeout_seconds: Optional[float] = None, + ): + context = {} + if operation: + 
context["operation"] = operation + if timeout_seconds is not None: + context["timeout"] = timeout_seconds + super().__init__(message, context) + self.operation = operation + self.timeout_seconds = timeout_seconds + + +__all__ = [ + "ContextEngineError", + "ValidationError", + "ParsingError", + "ChunkingError", + "EmbeddingError", + "IndexingError", + "DatabaseError", + "SearchError", + "ConfigurationError", + "ProviderError", + "CacheError", + "RateLimitError", + "TimeoutError", +] diff --git a/scripts/hybrid/elbow_detection.py b/scripts/hybrid/elbow_detection.py new file mode 100644 index 00000000..cadc8d84 --- /dev/null +++ b/scripts/hybrid/elbow_detection.py @@ -0,0 +1,363 @@ +"""Elbow detection for adaptive threshold computation. + +Mathematical approaches: +1. Curvature-based detection - finds point of maximum bending (2nd derivative) +2. Multi-changepoint detection - finds multiple quality tiers via recursive segmentation +3. Kneedle fallback - perpendicular distance method for edge cases + +Curvature formula: κ(i) = |f''(i)| / (1 + f'(i)²)^(3/2) +where f'(i) and f''(i) are discrete derivatives using central differences. +""" + +from __future__ import annotations + +import logging +from typing import Sequence, Union, List, Tuple + +import numpy as np + +logger = logging.getLogger(__name__) + + +def _discrete_curvature(y: np.ndarray) -> np.ndarray: + """Compute discrete curvature using central differences. 
+ + κ(i) = |y''(i)| / (1 + y'(i)²)^(3/2) + + First derivative: y'(i) = (y[i+1] - y[i-1]) / 2 + Second derivative: y''(i) = y[i+1] - 2*y[i] + y[i-1] + """ + n = len(y) + if n < 3: + return np.zeros(n) + + curvature = np.zeros(n) + + for i in range(1, n - 1): + y_prime = (y[i + 1] - y[i - 1]) / 2.0 + y_double_prime = y[i + 1] - 2.0 * y[i] + y[i - 1] + + denominator = (1.0 + y_prime ** 2) ** 1.5 + if denominator > 1e-10: + curvature[i] = abs(y_double_prime) / denominator + + return curvature + + +def _segment_cost(y: np.ndarray) -> float: + """Compute segment cost as negative log-likelihood under Gaussian model. + + Cost = n * log(variance) where variance = Σ(y - mean)² / n + Lower cost = more homogeneous segment. + """ + if len(y) < 2: + return 0.0 + variance = np.var(y) + if variance < 1e-10: + return 0.0 + return len(y) * np.log(variance) + + +def find_elbow_curvature(sorted_scores: Sequence[float]) -> int | None: + """Find elbow using maximum curvature (2nd derivative method). + + More mathematically rigorous than perpendicular distance: + - Curvature measures local bending intensity + - Invariant to linear transformation of axes + - Maximum curvature = point of diminishing returns + + Args: + sorted_scores: Scores sorted DESCENDING + + Returns: + Index of elbow point, or None if no significant elbow + """ + if len(sorted_scores) < 4: + return None + + scores = np.array(sorted_scores, dtype=np.float64) + + min_s, max_s = scores.min(), scores.max() + if max_s - min_s < 1e-10: + return None + + normalized = (scores - min_s) / (max_s - min_s) + + x = np.linspace(0, 1, len(normalized)) + + curvature = _discrete_curvature(normalized) + + search_start = 1 + search_end = len(curvature) - 1 + if search_end <= search_start: + return None + + max_idx = search_start + int(np.argmax(curvature[search_start:search_end])) + max_curvature = curvature[max_idx] + + if max_curvature < 0.1: + logger.debug(f"Curvature: No significant elbow (max_κ={max_curvature:.4f} < 0.1)") + return 
None + + logger.debug( + f"Curvature: Found elbow at index {max_idx} " + f"(κ={max_curvature:.4f}, score={sorted_scores[max_idx]:.3f})" + ) + return max_idx + + +def find_changepoints( + sorted_scores: Sequence[float], + max_changepoints: int = 3, + min_segment_size: int = 2, +) -> List[int]: + """Find multiple changepoints using recursive binary segmentation. + + Uses BIC penalty: β = log(n) to prevent overfitting. + + Args: + sorted_scores: Scores sorted DESCENDING + max_changepoints: Maximum number of changepoints to find + min_segment_size: Minimum segment size + + Returns: + List of changepoint indices (sorted), empty if none found + """ + if len(sorted_scores) < 2 * min_segment_size: + return [] + + scores = np.array(sorted_scores, dtype=np.float64) + n = len(scores) + + penalty = np.log(n) + + def find_best_split(start: int, end: int) -> Tuple[int, float]: + """Find best split point in segment [start, end).""" + if end - start < 2 * min_segment_size: + return -1, 0.0 + + segment = scores[start:end] + base_cost = _segment_cost(segment) + + best_idx = -1 + best_gain = 0.0 + + for split in range(start + min_segment_size, end - min_segment_size + 1): + left_cost = _segment_cost(scores[start:split]) + right_cost = _segment_cost(scores[split:end]) + + gain = base_cost - (left_cost + right_cost) - penalty + + if gain > best_gain: + best_gain = gain + best_idx = split + + return best_idx, best_gain + + changepoints = [] + segments = [(0, n)] + + while len(changepoints) < max_changepoints and segments: + best_segment_idx = -1 + best_split = -1 + best_gain = 0.0 + + for seg_idx, (start, end) in enumerate(segments): + split, gain = find_best_split(start, end) + if gain > best_gain: + best_gain = gain + best_split = split + best_segment_idx = seg_idx + + if best_split == -1: + break + + changepoints.append(best_split) + + start, end = segments.pop(best_segment_idx) + segments.append((start, best_split)) + segments.append((best_split, end)) + + return 
sorted(changepoints) + + +def find_elbow_kneedle(sorted_scores: Sequence[float]) -> int | None: + """Find elbow using perpendicular distance (Kneedle algorithm). + + Fallback method when curvature-based detection fails. + """ + if len(sorted_scores) < 3: + return None + + scores = np.array(sorted_scores, dtype=np.float64) + + min_score, max_score = scores.min(), scores.max() + if max_score - min_score < 1e-10: + return None + + normalized = (scores - min_score) / (max_score - min_score) + x = np.linspace(0, 1, len(normalized)) + + x1, y1 = x[0], normalized[0] + x2, y2 = x[-1], normalized[-1] + + if abs(x2 - x1) < 1e-10: + return None + + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + + numerator = np.abs(m * x - normalized + b) + denominator = np.sqrt(m ** 2 + 1) + distances = numerator / denominator + + elbow_idx = int(np.argmax(distances)) + + if distances[elbow_idx] < 0.01: + return None + + return elbow_idx + + +def compute_elbow_threshold( + chunks_or_scores: Union[Sequence[dict], Sequence[float]], + score_key: str = "score", + fallback_score_key: str = "rerank_score", + method: str = "curvature", +) -> float: + """Compute elbow threshold using specified method. 
+ + Args: + chunks_or_scores: List of chunks (dicts) or raw scores + score_key: Primary score key for dicts + fallback_score_key: Fallback score key + method: "curvature" (default), "kneedle", or "changepoint" + + Returns: + Threshold value at elbow point + """ + if not chunks_or_scores: + return 0.5 + + if isinstance(chunks_or_scores[0], dict): + chunk_list: Sequence[dict] = chunks_or_scores # type: ignore + scores = [] + for c in chunk_list: + score = c.get(score_key) + if score is None: + score = c.get(fallback_score_key, 0.0) + scores.append(float(score)) + else: + scores = [float(s) for s in chunks_or_scores] + + if not scores: + return 0.5 + + sorted_scores = sorted(scores, reverse=True) + + elbow_idx = None + + if method == "curvature": + elbow_idx = find_elbow_curvature(sorted_scores) + if elbow_idx is None: + elbow_idx = find_elbow_kneedle(sorted_scores) + elif method == "changepoint": + changepoints = find_changepoints(sorted_scores, max_changepoints=1) + if changepoints: + elbow_idx = changepoints[0] + else: + elbow_idx = find_elbow_kneedle(sorted_scores) + + if elbow_idx is not None and 0 <= elbow_idx < len(sorted_scores): + return float(sorted_scores[elbow_idx]) + + median_idx = len(sorted_scores) // 2 + return float(sorted_scores[median_idx]) + + +def compute_tier_thresholds( + chunks_or_scores: Union[Sequence[dict], Sequence[float]], + score_key: str = "score", + fallback_score_key: str = "rerank_score", + max_tiers: int = 3, +) -> List[float]: + """Compute multiple quality tier thresholds. + + Uses changepoint detection to find natural breaks in score distribution. 
+ + Args: + chunks_or_scores: List of chunks or raw scores + score_key: Primary score key + fallback_score_key: Fallback score key + max_tiers: Maximum number of tiers (changepoints + 1) + + Returns: + List of threshold values (descending), one per tier boundary + """ + if not chunks_or_scores: + return [] + + if isinstance(chunks_or_scores[0], dict): + chunk_list: Sequence[dict] = chunks_or_scores # type: ignore + scores = [] + for c in chunk_list: + score = c.get(score_key) + if score is None: + score = c.get(fallback_score_key, 0.0) + scores.append(float(score)) + else: + scores = [float(s) for s in chunks_or_scores] + + if not scores: + return [] + + sorted_scores = sorted(scores, reverse=True) + + changepoints = find_changepoints( + sorted_scores, + max_changepoints=max_tiers - 1, + min_segment_size=max(2, len(sorted_scores) // 10) + ) + + return [float(sorted_scores[cp]) for cp in changepoints] + + +def filter_by_elbow( + results: Sequence[dict], + score_key: str = "score", + fallback_score_key: str = "rerank_score", + min_results: int = 1, + method: str = "curvature", +) -> list[dict]: + """Filter results using elbow detection. 
+ + Args: + results: List of result dicts + score_key: Primary score key + fallback_score_key: Fallback score key + min_results: Minimum results to return + method: Detection method ("curvature", "kneedle", "changepoint") + + Returns: + Filtered results above elbow threshold + """ + if not results: + return [] + + threshold = compute_elbow_threshold( + results, score_key, fallback_score_key, method + ) + + def get_score(r: dict) -> float: + score = r.get(score_key) + if score is None: + score = r.get(fallback_score_key, 0.0) + return float(score) + + filtered = [r for r in results if get_score(r) >= threshold] + + if len(filtered) < min_results and len(results) >= min_results: + sorted_results = sorted(results, key=get_score, reverse=True) + return sorted_results[:min_results] + + return filtered if filtered else results[:min_results] diff --git a/scripts/hybrid/qdrant.py b/scripts/hybrid/qdrant.py index 33925c47..039ca6f1 100644 --- a/scripts/hybrid/qdrant.py +++ b/scripts/hybrid/qdrant.py @@ -844,6 +844,105 @@ def multi_granular_query( # Module exports # --------------------------------------------------------------------------- +def find_similar_chunks( + client, + chunk_id: str, + collection: str, + vec_name: str, + limit: int = 20, + threshold: float | None = None, + path_filter: str | None = None, +) -> List[Dict[str, Any]]: + """Find chunks similar to a given chunk by retrieving its vector and searching. + + Used for multi-hop search expansion - given a high-scoring chunk, + find its nearest neighbors in the vector space. 
+ + Args: + client: QdrantClient instance + chunk_id: ID of the chunk to find similar chunks for + collection: Collection name + vec_name: Vector name to use for similarity + limit: Maximum number of similar chunks to return + threshold: Optional minimum similarity score + path_filter: Optional path prefix to filter results + + Returns: + List of similar chunks with score, content, path, and full payload + """ + try: + points = client.retrieve( + collection_name=collection, + ids=[chunk_id], + with_vectors=[vec_name], + ) + except Exception: + points = client.retrieve( + collection_name=collection, + ids=[chunk_id], + with_vectors=True, + ) + + if not points: + return [] + + point = points[0] + vector = point.vector + if isinstance(vector, dict): + vector = vector.get(vec_name) + if not vector: + return [] + + must_not = [models.HasIdCondition(has_id=[chunk_id])] + must = [] + + if path_filter: + path_filter_clean = path_filter.rstrip("/") + must.append(models.FieldCondition( + key="metadata.path_prefix", + match=models.MatchValue(value=path_filter_clean), + )) + + flt = models.Filter(must=must, must_not=must_not) if must or must_not else None + + try: + results = client.search( + collection_name=collection, + query_vector=(vec_name, vector), + query_filter=flt, + limit=limit, + score_threshold=threshold, + with_payload=True, + ) + except TypeError: + results = client.search( + collection_name=collection, + query_vector=vector, + query_filter=flt, + limit=limit, + score_threshold=threshold, + with_payload=True, + ) + + output = [] + for r in results: + md = (r.payload or {}).get("metadata", {}) + output.append({ + "chunk_id": str(r.id), + "score": r.score, + "similarity": r.score, + "content": md.get("text", ""), + "path": md.get("path", ""), + "start_line": md.get("start_line"), + "end_line": md.get("end_line"), + "symbol": md.get("symbol"), + "kind": md.get("kind"), + "payload": r.payload, + }) + + return output + + __all__ = [ # Pool availability flag 
"_POOL_AVAILABLE", @@ -864,7 +963,7 @@ def multi_granular_query( "_get_client_endpoint", "_ensure_collection", "clear_ensured_collections", - # Collection name resolution + # Collection resolution "_collection", # Filter sanitization "_sanitize_filter_obj", @@ -877,6 +976,7 @@ def multi_granular_query( "sparse_lex_query", "dense_query", "multi_granular_query", + "find_similar_chunks", # Multi-granular config "MULTI_GRANULAR_VECTORS", "ENTITY_DENSE_NAME", @@ -890,3 +990,4 @@ def multi_granular_query( "LEX_SPARSE_MODE", "EF_SEARCH", ] + diff --git a/scripts/hybrid/termination.py b/scripts/hybrid/termination.py new file mode 100644 index 00000000..c5746b1d --- /dev/null +++ b/scripts/hybrid/termination.py @@ -0,0 +1,307 @@ +"""Smart termination for iterative search operations. + +Mathematical foundations: +1. Welford's algorithm - O(1) online variance for adaptive thresholds +2. Page-Hinkley test - detects mean shift in streaming data +3. Statistical termination - uses 2-sigma rule instead of fixed thresholds + +Welford's update: δ = x - μ, μ' = μ + δ/n, M2' = M2 + δ(x - μ') +Page-Hinkley: cumsum of (x - μ - δ), detect when max deviation exceeds threshold +""" + +from __future__ import annotations + +import logging +import time +from dataclasses import dataclass, field +from typing import Dict, List, Tuple, Sequence, Optional +import math + +logger = logging.getLogger(__name__) + + +@dataclass +class WelfordState: + """Online variance computation using Welford's algorithm.""" + n: int = 0 + mean: float = 0.0 + m2: float = 0.0 + + def update(self, x: float) -> None: + """O(1) update with new value.""" + self.n += 1 + delta = x - self.mean + self.mean += delta / self.n + delta2 = x - self.mean + self.m2 += delta * delta2 + + @property + def variance(self) -> float: + return self.m2 / self.n if self.n > 1 else 0.0 + + @property + def std(self) -> float: + return math.sqrt(self.variance) + + def adaptive_threshold(self, sigma_multiplier: float = 2.0) -> float: + """Return 
threshold as mean - sigma_multiplier * std.""" + return self.mean - sigma_multiplier * self.std + + +@dataclass +class PageHinkleyState: + """Page-Hinkley test for DOWNWARD mean shift detection (score degradation). + + Detects when scores drop significantly below the running mean. + Cumsum formula: cumsum += (mean - x + delta) + When x consistently falls below mean, cumsum grows and triggers detection. + + This is the inverse of the standard PH test (which detects upward drift). + Optimized for search relevance degradation detection. + """ + delta: float = 0.005 + threshold: float = 0.5 + n: int = 0 + mean: float = 0.0 + cumsum: float = 0.0 + cumsum_max: float = 0.0 + + def update(self, x: float) -> bool: + """Update and return True if downward drift detected.""" + self.n += 1 + + if self.n == 1: + self.mean = x + return False + + self.mean = ((self.n - 1) * self.mean + x) / self.n + + # cumsum += (mean - x + delta): grows when x < mean + self.cumsum += self.mean - x + self.delta + self.cumsum_max = max(self.cumsum_max, self.cumsum) + + if self.cumsum > self.threshold: + return True + + return False + + def reset(self) -> None: + self.n = 0 + self.mean = 0.0 + self.cumsum = 0.0 + self.cumsum_max = 0.0 + + +@dataclass +class TerminationConfig: + time_limit: float = 5.0 + result_limit: int = 500 + min_candidates_for_expansion: int = 5 + + use_adaptive_threshold: bool = True + sigma_multiplier: float = 2.0 + fixed_degradation_threshold: float = 0.15 + + use_page_hinkley: bool = True + page_hinkley_delta: float = 0.005 + page_hinkley_threshold: float = 0.5 + + min_relevance_score: float = 0.3 + top_n_to_track: int = 5 + + min_iterations_before_stop: int = 2 + + +class TerminationChecker: + """Statistically-grounded termination for iterative search.""" + + def __init__(self, config: TerminationConfig | None = None): + self.config = config or TerminationConfig() + self.start_time = time.perf_counter() + self.iteration = 0 + + self.tracked_chunk_scores: Dict[str, float] = 
{} + + self.score_stats = WelfordState() + self.page_hinkley = PageHinkleyState( + delta=self.config.page_hinkley_delta, + threshold=self.config.page_hinkley_threshold, + ) + + self.top_scores_history: List[float] = [] + + def reset(self) -> None: + self.start_time = time.perf_counter() + self.iteration = 0 + self.tracked_chunk_scores.clear() + self.score_stats = WelfordState() + self.page_hinkley.reset() + self.top_scores_history.clear() + + def elapsed(self) -> float: + return time.perf_counter() - self.start_time + + def check( + self, + results: Sequence[dict], + score_key: str = "score", + id_key: str = "chunk_id", + ) -> Tuple[bool, str]: + """Check termination conditions with statistical methods. + + Returns: + (should_terminate, reason) + """ + self.iteration += 1 + + if self.elapsed() >= self.config.time_limit: + logger.debug(f"Termination: time limit {self.config.time_limit}s") + return True, "time_limit" + + if len(results) >= self.config.result_limit: + logger.debug(f"Termination: result limit {self.config.result_limit}") + return True, "result_limit" + + def get_numeric_score(r: dict) -> float: + score = r.get(score_key, 0) + if score is None or isinstance(score, bool): + return 0.0 + try: + return float(score) + except (TypeError, ValueError): + return 0.0 + + high_scoring = [r for r in results if get_numeric_score(r) > 0] + if len(high_scoring) < self.config.min_candidates_for_expansion: + logger.debug(f"Termination: insufficient candidates ({len(high_scoring)})") + return True, "insufficient_candidates" + + sorted_results = sorted(results, key=lambda x: -get_numeric_score(x)) + top_n = sorted_results[:self.config.top_n_to_track] + + if top_n: + top_score = get_numeric_score(top_n[0]) + self.score_stats.update(top_score) + self.top_scores_history.append(top_score) + + if self.iteration >= self.config.min_iterations_before_stop: + + if self.config.use_page_hinkley and top_n: + top_score = get_numeric_score(top_n[0]) + if 
self.page_hinkley.update(top_score): + logger.debug("Termination: Page-Hinkley detected score drift") + return True, "score_drift_detected" + + if self.tracked_chunk_scores and self.iteration > 2: + if self.config.use_adaptive_threshold: + threshold = self.score_stats.adaptive_threshold( + self.config.sigma_multiplier + ) + if threshold <= 0: + threshold = self.config.fixed_degradation_threshold + else: + threshold = self.config.fixed_degradation_threshold + + max_drop = 0.0 + for chunk_id, prev_score in self.tracked_chunk_scores.items(): + current_score = next( + (get_numeric_score(r) for r in results if r.get(id_key) == chunk_id), + 0.0, + ) + if current_score < prev_score: + max_drop = max(max_drop, prev_score - current_score) + + if max_drop >= threshold: + logger.debug( + f"Termination: score degradation {max_drop:.3f} >= " + f"threshold {threshold:.3f}" + ) + return True, "score_degradation" + + self.tracked_chunk_scores.clear() + for r in top_n: + chunk_id = r.get(id_key) + if chunk_id: + self.tracked_chunk_scores[chunk_id] = get_numeric_score(r) + + if top_n: + min_score = min(get_numeric_score(r) for r in top_n) + if min_score < self.config.min_relevance_score: + logger.debug(f"Termination: min relevance {min_score:.3f}") + return True, "min_relevance" + + return False, "" + + def get_stats(self) -> Dict[str, float]: + return { + "iterations": self.iteration, + "elapsed_seconds": round(self.elapsed(), 3), + "tracked_chunks": len(self.tracked_chunk_scores), + "score_mean": round(self.score_stats.mean, 4), + "score_std": round(self.score_stats.std, 4), + "adaptive_threshold": round( + self.score_stats.adaptive_threshold(self.config.sigma_multiplier), 4 + ), + "page_hinkley_cumsum": round(self.page_hinkley.cumsum, 4), + } + + +def mann_whitney_u(x: Sequence[float], y: Sequence[float]) -> Tuple[float, float]: + """Mann-Whitney U test for comparing two score distributions. + + Returns (U statistic, approximate p-value using normal approximation). 
+ Useful for comparing score quality across iterations. + """ + nx, ny = len(x), len(y) + if nx == 0 or ny == 0: + return 0.0, 1.0 + + combined = [(v, 0) for v in x] + [(v, 1) for v in y] + combined.sort(key=lambda t: t[0]) + + i = 0 + rank_list = [] + while i < len(combined): + j = i + while j < len(combined) and combined[j][0] == combined[i][0]: + j += 1 + avg_rank = (i + j + 1) / 2.0 + for k in range(i, j): + rank_list.append((combined[k][0], combined[k][1], avg_rank)) + i = j + + r1 = sum(rank for val, group, rank in rank_list if group == 0) + + u1 = r1 - nx * (nx + 1) / 2 + u2 = nx * ny - u1 + u = min(u1, u2) + + mu = nx * ny / 2 + sigma = math.sqrt(nx * ny * (nx + ny + 1) / 12) + + if sigma == 0: + return u, 1.0 + + z = (u - mu) / sigma + + p = 2 * (1 - _normal_cdf(abs(z))) + + return u, p + + +def _normal_cdf(x: float) -> float: + """Standard normal CDF approximation (Abramowitz & Stegun).""" + a1 = 0.254829592 + a2 = -0.284496736 + a3 = 1.421413741 + a4 = -1.453152027 + a5 = 1.061405429 + p = 0.3275911 + + sign = 1 if x >= 0 else -1 + x = abs(x) + + t = 1.0 / (1.0 + p * x) + y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * math.exp(-x * x / 2) + + return 0.5 * (1.0 + sign * y) diff --git a/scripts/hybrid_qdrant.py b/scripts/hybrid_qdrant.py index 2498ef8e..789700b9 100644 --- a/scripts/hybrid_qdrant.py +++ b/scripts/hybrid_qdrant.py @@ -1,3 +1,4 @@ #!/usr/bin/env python3 """Shim for backward compatibility. 
See scripts/hybrid/qdrant.py""" from scripts.hybrid.qdrant import * +from scripts.hybrid.qdrant import __all__ diff --git a/scripts/hybrid_search.py b/scripts/hybrid_search.py index 56805341..4fbd1e25 100644 --- a/scripts/hybrid_search.py +++ b/scripts/hybrid_search.py @@ -245,6 +245,23 @@ _IMPL_INTENT_PATTERNS, ) +# --------------------------------------------------------------------------- +# Elbow detection for adaptive filtering +# --------------------------------------------------------------------------- +# Lazy import to avoid hard numpy dependency when feature is disabled +_filter_by_elbow = None + +def _get_filter_by_elbow(): + """Lazy load filter_by_elbow to avoid numpy import when disabled.""" + global _filter_by_elbow + if _filter_by_elbow is None: + from scripts.hybrid.elbow_detection import filter_by_elbow + _filter_by_elbow = filter_by_elbow + return _filter_by_elbow + +# Environment variable for elbow filtering (opt-in) +ELBOW_FILTER_ENABLED = _env_truthy(os.environ.get("HYBRID_ELBOW_FILTER"), False) + # --------------------------------------------------------------------------- # Re-exports from hybrid_expand # --------------------------------------------------------------------------- @@ -3011,6 +3028,24 @@ def _resolve(seg: str) -> list[str]: if why is not None: item["why"] = why items.append(item) + + # Apply elbow detection filter if enabled (adaptive threshold based on score distribution) + if ELBOW_FILTER_ENABLED and items: + original_count = len(items) + # Use rerank_score if available, otherwise use score + items = _get_filter_by_elbow()( + items, + score_key="rerank_score", + fallback_score_key="score", + min_results=max(1, limit // 2) if limit > 0 else 0, # Keep at least half the requested limit + ) + if os.environ.get("DEBUG_HYBRID_SEARCH"): + logger.debug( + f"Elbow filter: {original_count} -> {len(items)} results " + f"(threshold based on curvature method)" + ) + _dt("elbow_filter") + if _USE_CACHE and cache_key is not None: if 
UNIFIED_CACHE_AVAILABLE: _RESULTS_CACHE.set(cache_key, items) diff --git a/scripts/indexing_admin.py b/scripts/indexing_admin.py index fd2389fd..1e8abf5f 100644 --- a/scripts/indexing_admin.py +++ b/scripts/indexing_admin.py @@ -1078,7 +1078,7 @@ def recreate_collection_qdrant(*, qdrant_url: str, api_key: Optional[str], colle # Also delete the graph collection if it exists # Graph collections are tightly coupled to their main collection - # The decision to recreate happens during ingest (based on INDEX_GRAPH_EDGES) + # Graph edges are always indexed (Qdrant flat graph is always on) if get_graph_collection_name_t is not None: graph_name = get_graph_collection_name_t(name) try: diff --git a/scripts/ingest/cast_chunker.py b/scripts/ingest/cast_chunker.py index 05155e4c..db8f0f8a 100644 --- a/scripts/ingest/cast_chunker.py +++ b/scripts/ingest/cast_chunker.py @@ -181,7 +181,7 @@ def _non_whitespace_chars(self, text: str) -> int: # Deduplication # ------------------------------------------------------------------------- def _deduplicate_chunks(self, chunks: List[SemanticChunk]) -> List[SemanticChunk]: - """Remove chunks with identical content, keeping most specific.""" + """Remove chunks with identical content, keeping most specific (legacy).""" if not self.config.deduplicate or not chunks: return chunks @@ -189,7 +189,6 @@ def _deduplicate_chunks(self, chunks: List[SemanticChunk]) -> List[SemanticChunk for chunk in chunks: key = chunk.content.strip() if key in seen_content: - # Keep the more specific one (DEFINITION > BLOCK > COMMENT) existing = seen_content[key] priority = {ConceptType.DEFINITION: 3, ConceptType.BLOCK: 2, ConceptType.COMMENT: 1, ConceptType.IMPORT: 2, @@ -201,6 +200,18 @@ def _deduplicate_chunks(self, chunks: List[SemanticChunk]) -> List[SemanticChunk return list(seen_content.values()) + def _deduplicate_chunks_v2( + self, chunks: List[SemanticChunk], language: str + ) -> List[SemanticChunk]: + """O(n log n) deduplication with substring detection.""" 
+ if not self.config.deduplicate or not chunks: + return chunks + try: + from scripts.ingest.chunk_deduplication import deduplicate_semantic_chunks + return deduplicate_semantic_chunks(chunks, language) + except ImportError: + return self._deduplicate_chunks(chunks) + # ------------------------------------------------------------------------- # Merge Logic # ------------------------------------------------------------------------- @@ -604,8 +615,8 @@ def chunk( parent=None, )] - # Step 2: Deduplicate - chunks = self._deduplicate_chunks(chunks) + # Step 2: Deduplicate (O(n log n) with substring detection) + chunks = self._deduplicate_chunks_v2(chunks, language) # Step 3: Group by concept type by_concept: Dict[ConceptType, List[SemanticChunk]] = {} diff --git a/scripts/ingest/chunk_deduplication.py b/scripts/ingest/chunk_deduplication.py new file mode 100644 index 00000000..5eb7aaf9 --- /dev/null +++ b/scripts/ingest/chunk_deduplication.py @@ -0,0 +1,287 @@ +"""High-performance chunk deduplication with O(n log n) complexity. + +Two-stage deduplication: +1. Exact content matching via hash table (O(n)) +2. 
Substring detection via sorted interval scan (O(n log n)) + +Specificity scoring uses weighted formula: + score = w_type * type_weight + w_size * log(line_count) + w_name * has_name + +where: + - type_weight: structural importance (definition > block > comment) + - log(line_count): information content (more lines = more context) + - has_name: named symbols are more referenceable +""" + +from __future__ import annotations + +import logging +import math +from collections import defaultdict +from typing import Sequence, TypeVar, Dict, Any + +import xxhash + +logger = logging.getLogger(__name__) + +T = TypeVar("T", bound=dict) + +TYPE_WEIGHTS: Dict[str, float] = { + "function": 1.0, + "method": 1.0, + "class": 1.0, + "interface": 1.0, + "struct": 1.0, + "enum": 1.0, + "definition": 1.0, + "type_alias": 0.8, + "type": 0.8, + "import": 0.6, + "comment": 0.4, + "docstring": 0.4, + "block": 0.3, + "array": 0.2, + "structure": 0.1, +} + +SPECIFICITY_WEIGHTS = { + "type": 0.5, + "size": 0.3, + "name": 0.2, +} + + +def normalize_content(content: str) -> str: + """Normalize content for consistent comparison.""" + return content.replace("\r\n", "\n").replace("\r", "\n").strip() + + +def _extract_type_name(chunk: dict) -> str: + """Extract normalized type name from chunk.""" + chunk_type = chunk.get("chunk_type") or chunk.get("concept") or chunk.get("type", "") + if isinstance(chunk_type, str): + return chunk_type.lower() + elif hasattr(chunk_type, "value"): + return str(chunk_type.value).lower() + elif hasattr(chunk_type, "name"): + return chunk_type.name.lower() + return str(chunk_type).lower() if chunk_type else "" + + +def compute_specificity_score(chunk: dict) -> float: + """Compute specificity score using weighted formula. + + score = w_type * type_weight + w_size * log(1 + line_count) + w_name * has_name + + Higher score = more specific, should be kept over lower-scoring duplicates. 
+ """ + type_name = _extract_type_name(chunk) + type_weight = TYPE_WEIGHTS.get(type_name, 0.0) + + start_line = chunk.get("start_line", 0) + end_line = chunk.get("end_line", 0) + line_count = max(1, end_line - start_line + 1) + size_score = math.log(1 + line_count) / math.log(1000) + + has_name = 1.0 if chunk.get("name") or chunk.get("symbol") else 0.0 + + score = ( + SPECIFICITY_WEIGHTS["type"] * type_weight + + SPECIFICITY_WEIGHTS["size"] * min(1.0, size_score) + + SPECIFICITY_WEIGHTS["name"] * has_name + ) + + return score + + +def get_chunk_specificity(chunk: dict) -> int: + """Get integer specificity ranking (legacy interface, 0-4 scale).""" + type_name = _extract_type_name(chunk) + weight = TYPE_WEIGHTS.get(type_name, 0.0) + + if weight >= 0.9: + return 4 + elif weight >= 0.7: + return 3 + elif weight >= 0.5: + return 2 + elif weight >= 0.3: + return 1 + return 0 + + +def deduplicate_chunks( + chunks: Sequence[T], + language: str | None = None, + content_key: str = "code", +) -> list[T]: + """Deduplicate chunks using hash-based exact match + interval-based substring detection. 
+ + Args: + chunks: List of chunk dictionaries + language: Optional language for language-specific exemptions + content_key: Key to extract content from chunks (default: "code") + + Returns: + Deduplicated list of chunks + """ + if not chunks: + return [] + + # Language exemptions: Vue and Haskell preserve duplicates + if language and language.lower() in ("vue", "vue_template", "haskell"): + return list(chunks) + + # Stage 1: Exact content deduplication via hash table (O(n)) + exact_deduplicated = _deduplicate_exact_content(chunks, content_key) + + # Stage 2: Substring detection via interval scan (O(n log n)) + final = _remove_substring_overlaps(exact_deduplicated, content_key) + + logger.debug( + f"Deduplication: {len(chunks)} -> {len(exact_deduplicated)} (exact) -> {len(final)} (substring)" + ) + + return final + + +def _deduplicate_exact_content(chunks: Sequence[T], content_key: str) -> list[T]: + """Remove chunks with identical normalized content, keeping highest specificity.""" + hash_to_chunks: dict[int, list[T]] = defaultdict(list) + + for chunk in chunks: + content = chunk.get(content_key, "") + if not content: + content = chunk.get("content", "") or chunk.get("text", "") + + normalized = normalize_content(content) + if not normalized: + continue + + content_hash = xxhash.xxh3_64(normalized.encode("utf-8")).intdigest() + hash_to_chunks[content_hash].append(chunk) + + result = [] + for chunk_list in hash_to_chunks.values(): + if len(chunk_list) == 1: + result.append(chunk_list[0]) + else: + best = max( + chunk_list, + key=lambda c: ( + get_chunk_specificity(c), + -(c.get("end_line", 0) - c.get("start_line", 0)), + ), + ) + result.append(best) + + return result + + +def _remove_substring_overlaps(chunks: Sequence[T], content_key: str) -> list[T]: + """Remove BLOCK chunks that are substrings of DEFINITION/STRUCTURE chunks.""" + definitions = [] + structures = [] + blocks = [] + other = [] + + for chunk in chunks: + specificity = get_chunk_specificity(chunk) + 
type_name = _extract_type_name(chunk) + if specificity == 1: # BLOCK-like + blocks.append(chunk) + elif specificity >= 2: # DEFINITION-like (includes type_alias, type) + definitions.append(chunk) + elif type_name == "structure": # STRUCTURE-like + structures.append(chunk) + else: + other.append(chunk) + + containers = definitions + structures + containers.sort(key=lambda c: c.get("start_line", 0)) + + final = other + containers + + for block in blocks: + block_content = normalize_content( + block.get(content_key, "") or block.get("content", "") or block.get("text", "") + ) + block_start = block.get("start_line", 0) + block_end = block.get("end_line", 0) + + is_substring = False + for container in _find_overlapping(containers, block_start, block_end): + def_content = normalize_content( + container.get(content_key, "") or container.get("content", "") or container.get("text", "") + ) + if block_content in def_content and len(block_content) < len(def_content): + is_substring = True + break + + if not is_substring: + final.append(block) + + return final + + +def _find_overlapping(sorted_chunks: list[T], query_start: int, query_end: int) -> list[T]: + """Find chunks whose line ranges overlap with [query_start, query_end].""" + overlapping = [] + for chunk in sorted_chunks: + chunk_start = chunk.get("start_line", 0) + chunk_end = chunk.get("end_line", 0) + + if chunk_end < query_start: + continue + if chunk_start > query_end: + break + + overlapping.append(chunk) + + return overlapping + + +def deduplicate_semantic_chunks( + chunks: Sequence, + language: str | None = None, +) -> list: + """Deduplicate SemanticChunk objects using O(n log n) algorithm. + + Converts SemanticChunk dataclass objects to dicts, deduplicates, + and returns the original objects. 
+ + Args: + chunks: List of SemanticChunk objects (with content, start_line, end_line, concept) + language: Optional language for exemptions (Vue, Haskell) + + Returns: + Deduplicated list of SemanticChunk objects + """ + if not chunks: + return [] + + chunk_dicts = [] + for i, c in enumerate(chunks): + concept = getattr(c, "concept", None) + if concept is not None: + if hasattr(concept, "value"): + concept_str = concept.value + elif hasattr(concept, "name"): + concept_str = concept.name + else: + concept_str = str(concept) + else: + concept_str = "" + + chunk_dicts.append({ + "content": getattr(c, "content", ""), + "start_line": getattr(c, "start_line", 0), + "end_line": getattr(c, "end_line", 0), + "concept": concept_str, + "_idx": i, + }) + + deduped_dicts = deduplicate_chunks(chunk_dicts, language, content_key="content") + + kept_indices = {d["_idx"] for d in deduped_dicts} + return [c for i, c in enumerate(chunks) if i in kept_indices] diff --git a/scripts/ingest/file_discovery_cache.py b/scripts/ingest/file_discovery_cache.py new file mode 100644 index 00000000..dd850579 --- /dev/null +++ b/scripts/ingest/file_discovery_cache.py @@ -0,0 +1,198 @@ +"""Cache for file discovery operations to reduce filesystem overhead. + +Caches glob pattern matching results with TTL and directory mtime validation. +Useful for repeated directory scans during watch mode or incremental indexing. 
+ +Usage: + cache = FileDiscoveryCache(max_entries=100, ttl_seconds=300) + + # Get files matching patterns (cached) + files = cache.get_files( + directory=Path("/project"), + patterns=["**/*.py", "**/*.js"], + exclude_patterns=["**/node_modules/**"] + ) + + # Invalidate when directory changes + cache.invalidate_directory(Path("/project/src")) +""" + +from __future__ import annotations + +import logging +import time +from collections import OrderedDict +from fnmatch import fnmatch +from pathlib import Path +from typing import List, Dict, Optional, Tuple, Any + +logger = logging.getLogger(__name__) + + +class FileDiscoveryCache: + """LRU cache for file discovery operations with TTL and mtime validation.""" + + def __init__(self, max_entries: int = 100, ttl_seconds: int = 300): + self.max_entries = max_entries + self.ttl_seconds = ttl_seconds + self._cache: OrderedDict[str, Tuple[List[Path], float, float]] = OrderedDict() + self._hits = 0 + self._misses = 0 + self._evictions = 0 + self._invalidations = 0 + + def get_files( + self, + directory: Path, + patterns: List[str], + exclude_patterns: Optional[List[str]] = None, + ) -> List[Path]: + """Get files matching patterns with caching.""" + cache_key = self._make_cache_key(directory, patterns, exclude_patterns) + + cached_result = self._get_from_cache(cache_key, directory) + if cached_result is not None: + self._hits += 1 + return cached_result + + self._misses += 1 + files = self._discover_files(directory, patterns, exclude_patterns) + self._store_in_cache(cache_key, files, directory) + return files + + def invalidate_directory(self, directory: Path) -> int: + """Invalidate all cache entries for a directory. 
Returns count removed.""" + dir_str = str(directory) + keys_to_remove = [key for key in self._cache if key.startswith(f"{dir_str}|")] + + for key in keys_to_remove: + del self._cache[key] + self._invalidations += 1 + + return len(keys_to_remove) + + def clear(self) -> None: + """Clear all cache entries.""" + count = len(self._cache) + self._cache.clear() + self._evictions += count + + def get_stats(self) -> Dict[str, Any]: + """Get cache statistics.""" + total_requests = self._hits + self._misses + hit_rate = (self._hits / total_requests * 100) if total_requests > 0 else 0.0 + + return { + "hits": self._hits, + "misses": self._misses, + "evictions": self._evictions, + "invalidations": self._invalidations, + "cache_size": len(self._cache), + "max_entries": self.max_entries, + "ttl_seconds": self.ttl_seconds, + "hit_rate_percent": round(hit_rate, 2), + } + + def _make_cache_key( + self, directory: Path, patterns: List[str], exclude_patterns: Optional[List[str]] + ) -> str: + """Create a cache key from directory and patterns.""" + patterns_str = "|".join(sorted(patterns)) + exclude_str = "|".join(sorted(exclude_patterns or [])) + return f"{directory}|{patterns_str}|{exclude_str}" + + def _get_from_cache(self, cache_key: str, directory: Path) -> Optional[List[Path]]: + """Get entry from cache if valid (TTL and mtime).""" + if cache_key not in self._cache: + return None + + files, timestamp, cached_mtime = self._cache[cache_key] + current_time = time.time() + + if current_time - timestamp > self.ttl_seconds: + del self._cache[cache_key] + self._evictions += 1 + return None + + try: + current_mtime = directory.stat().st_mtime + if current_mtime > cached_mtime: + del self._cache[cache_key] + self._invalidations += 1 + return None + except OSError: + del self._cache[cache_key] + self._invalidations += 1 + return None + + self._cache.move_to_end(cache_key) + return files + + def _store_in_cache(self, cache_key: str, files: List[Path], directory: Path) -> None: + """Store 
files in cache with mtime tracking.""" + try: + directory_mtime = directory.stat().st_mtime + except OSError: + return + + while len(self._cache) >= self.max_entries: + oldest_key = next(iter(self._cache)) + del self._cache[oldest_key] + self._evictions += 1 + + self._cache[cache_key] = (files, time.time(), directory_mtime) + + def _discover_files( + self, directory: Path, patterns: List[str], exclude_patterns: Optional[List[str]] + ) -> List[Path]: + """Perform actual file discovery via glob.""" + try: + files: List[Path] = [] + for pattern in patterns: + files.extend(directory.glob(pattern)) + + seen = set() + unique_files = [] + for file_path in files: + if file_path not in seen: + seen.add(file_path) + unique_files.append(file_path) + files = unique_files + + if exclude_patterns: + filtered_files = [] + for file_path in files: + try: + rel_path = file_path.relative_to(directory) + except ValueError: + rel_path = file_path + excluded = any( + fnmatch(str(rel_path), ep) or fnmatch(str(file_path), ep) + for ep in exclude_patterns + ) + if not excluded: + filtered_files.append(file_path) + files = filtered_files + + return files + + except Exception as e: + logger.debug(f"Failed to discover files in {directory}: {e}") + return [] + + +_default_cache: Optional[FileDiscoveryCache] = None + + +def get_default_file_discovery_cache() -> FileDiscoveryCache: + """Get or create the default global file discovery cache instance.""" + global _default_cache + if _default_cache is None: + _default_cache = FileDiscoveryCache() + return _default_cache + + +__all__ = [ + "FileDiscoveryCache", + "get_default_file_discovery_cache", +] diff --git a/scripts/ingest/language_mappings/base.py b/scripts/ingest/language_mappings/base.py index c68841cb..289e1d5a 100644 --- a/scripts/ingest/language_mappings/base.py +++ b/scripts/ingest/language_mappings/base.py @@ -158,42 +158,78 @@ def get_expression_preview(self, expr: str, max_length: int = 20) -> str: expr = expr[:max_length - 3] + 
"..." return expr if expr else "expr" - def find_child_by_type(self, node: Any, node_type: str) -> Optional[Any]: - """Find first child of specified type.""" - if not TREE_SITTER_AVAILABLE or node is None: - return None - for i in range(node.child_count): - child = node.child(i) - if child and child.type == node_type: - return child - return None + # ------------------------------------------------------------------------- + # Constant extraction (UPPER_SNAKE_CASE pattern) + # ------------------------------------------------------------------------- - def find_children_by_type(self, node: Any, node_type: str) -> List[Any]: - """Find all children of specified type.""" - if not TREE_SITTER_AVAILABLE or node is None: - return [] - return [node.child(i) for i in range(node.child_count) - if node.child(i) and node.child(i).type == node_type] + def extract_constants( + self, concept: ConceptType, captures: Dict[str, Any], content: bytes + ) -> Optional[List[Dict[str, Any]]]: + """Extract constants from definition captures. + + Override in language-specific mappings for custom constant detection. + Default implementation detects UPPER_SNAKE_CASE patterns. + + Returns: + List of {"name": str, "value": str} dicts, or None if not a constant + """ + import re + + if concept != ConceptType.DEFINITION: + return None + + name = self.extract_name(concept, captures, content) + if not name: + return None + + if not re.match(r"^_?[A-Z][A-Z0-9_]*$", name): + return None + + text = self.extract_content(concept, captures, content) + value = "" + + for pattern in [ + r"=\s*(.+?)(?:\n|$)", + r":\s*\w+\s*=\s*(.+?)(?:\n|$)", + ]: + match = re.search(pattern, text) + if match: + value = match.group(1).strip() + break + + if len(value) > MAX_CONSTANT_VALUE_LENGTH: + value = value[:MAX_CONSTANT_VALUE_LENGTH] + "..." 
+ + return [{"name": name, "value": value}] - def get_node_line_range(self, node: Any) -> tuple: - """Get (start_line, end_line) 1-based.""" - if not TREE_SITTER_AVAILABLE or node is None: - return (1, 1) - return (node.start_point[0] + 1, node.end_point[0] + 1) + # ------------------------------------------------------------------------- + # Import resolution (override per-language) + # ------------------------------------------------------------------------- - def get_node_byte_range(self, node: Any) -> tuple: - """Get (start_byte, end_byte).""" - if not TREE_SITTER_AVAILABLE or node is None: - return (0, 0) - return (node.start_byte, node.end_byte) + def resolve_import_path( + self, import_text: str, base_dir: str, source_file: str + ) -> Optional[str]: + """Resolve import statement to actual file path. + + Override in language-specific mappings. Default returns None. + + Args: + import_text: The import statement text + base_dir: Base directory of the project + source_file: Path to the file containing the import + + Returns: + Resolved file path, or None if cannot resolve + """ + return None - def walk_tree(self, node: Any) -> Iterator[Any]: - """Walk all nodes depth-first.""" - if not TREE_SITTER_AVAILABLE or node is None: - return - yield node - for i in range(node.child_count): - child = node.child(i) - if child: - yield from self.walk_tree(child) + def get_import_module(self, import_text: str) -> Optional[str]: + """Extract module name from import statement. + + Override in language-specific mappings. 
+ + Returns: + Module name, or None if cannot parse + """ + return None diff --git a/scripts/ingest/language_mappings/go.py b/scripts/ingest/language_mappings/go.py index 514952b6..c67cda7f 100644 --- a/scripts/ingest/language_mappings/go.py +++ b/scripts/ingest/language_mappings/go.py @@ -146,11 +146,9 @@ def get_query_for_concept(self, concept: ConceptType) -> str | None: elif concept == ConceptType.IMPORT: return """ - (import_declaration - (import_spec - path: (interpreted_string_literal) @import_path - ) @import_spec - ) @definition + (import_spec + path: (interpreted_string_literal) @import_path + ) @import (package_clause (package_identifier) @package_name diff --git a/scripts/ingest/language_mappings/javascript.py b/scripts/ingest/language_mappings/javascript.py index 882eca55..06297a4e 100644 --- a/scripts/ingest/language_mappings/javascript.py +++ b/scripts/ingest/language_mappings/javascript.py @@ -34,7 +34,7 @@ TSNode = None -class JavaScriptMapping(BaseMapping, JSFamilyExtraction): +class JavaScriptMapping(JSFamilyExtraction, BaseMapping): """JavaScript language mapping for tree-sitter parsing. Provides JavaScript-specific queries and extraction methods for: diff --git a/scripts/ingest/language_mappings/jsx.py b/scripts/ingest/language_mappings/jsx.py index bc96f155..bc0fe94e 100644 --- a/scripts/ingest/language_mappings/jsx.py +++ b/scripts/ingest/language_mappings/jsx.py @@ -46,7 +46,7 @@ def __init__(self): """Initialize JSX mapping.""" # Initialize with JSX language instead of JavaScript super().__init__() - self.language = Language.JSX + self.language = "jsx" def get_function_query(self) -> str: """Get tree-sitter query for JSX function definitions including React components. 
diff --git a/scripts/ingest/language_mappings/svelte.py b/scripts/ingest/language_mappings/svelte.py index dc59f63b..8ad0ead8 100644 --- a/scripts/ingest/language_mappings/svelte.py +++ b/scripts/ingest/language_mappings/svelte.py @@ -36,7 +36,7 @@ class SvelteMapping(TypeScriptMapping): def __init__(self) -> None: """Initialize Svelte mapping (delegates to TypeScript for script parsing).""" super().__init__() - self.language = Language.SVELTE # Override to SVELTE + self.language = "svelte" # Override to SVELTE # Section extraction patterns SCRIPT_PATTERN = re.compile( diff --git a/scripts/ingest/language_mappings/tsx.py b/scripts/ingest/language_mappings/tsx.py index 7bc1b4a5..d0fe11e7 100644 --- a/scripts/ingest/language_mappings/tsx.py +++ b/scripts/ingest/language_mappings/tsx.py @@ -43,7 +43,7 @@ class TSXMapping(TypeScriptMapping): def __init__(self): """Initialize TSX mapping.""" # Initialize with TSX language instead of TypeScript - BaseMapping.__init__(self, Language.TSX) + BaseMapping.__init__(self, "tsx") def get_function_query(self) -> str: """Get tree-sitter query for TSX function definitions including typed React components. diff --git a/scripts/ingest/language_mappings/typescript.py b/scripts/ingest/language_mappings/typescript.py index f17ffddf..7acd9920 100644 --- a/scripts/ingest/language_mappings/typescript.py +++ b/scripts/ingest/language_mappings/typescript.py @@ -35,7 +35,7 @@ # TSNode is already defined in TYPE_CHECKING block -class TypeScriptMapping(BaseMapping, JSFamilyExtraction): +class TypeScriptMapping(JSFamilyExtraction, BaseMapping): """TypeScript language mapping for tree-sitter parsing. 
This mapping handles TypeScript-specific AST patterns including: @@ -172,6 +172,12 @@ def get_query_for_concept(self, concept: "ConceptType") -> str | None: # type: return """ (comment) @definition """ + elif concept == ConceptType.IMPORT: + return """ + (import_statement + source: (string) @import_path + ) @import + """ return None # extract_name / extract_metadata / extract_content are inherited diff --git a/scripts/ingest/language_mappings/vue.py b/scripts/ingest/language_mappings/vue.py index 99c672ee..96085a54 100644 --- a/scripts/ingest/language_mappings/vue.py +++ b/scripts/ingest/language_mappings/vue.py @@ -52,7 +52,7 @@ class VueMapping(TypeScriptMapping): def __init__(self) -> None: """Initialize Vue mapping (delegates to TypeScript for script parsing).""" super().__init__() - self.language = Language.VUE # Override to VUE + self.language = "vue" # Override to VUE # Section extraction patterns SCRIPT_PATTERN = re.compile( diff --git a/scripts/ingest/models.py b/scripts/ingest/models.py new file mode 100644 index 00000000..00c9deb6 --- /dev/null +++ b/scripts/ingest/models.py @@ -0,0 +1,264 @@ +"""Domain models for code indexing. + +Frozen dataclasses for immutability and hashability. +All models validate their invariants in __post_init__. 
+ +Usage: + chunk = Chunk( + id="abc123", + content="def foo(): pass", + start_line=1, + end_line=1, + file_path="/src/main.py", + language="python", + chunk_type=ChunkType.DEFINITION, + ) +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, List, Optional, FrozenSet + +from scripts.exceptions import ValidationError + + +class ChunkType(str, Enum): + """Universal chunk types for code chunking.""" + DEFINITION = "definition" + BLOCK = "block" + COMMENT = "comment" + IMPORT = "import" + STRUCTURE = "structure" + UNKNOWN = "unknown" + + +class SymbolKind(str, Enum): + """Symbol kinds for code analysis.""" + FUNCTION = "function" + METHOD = "method" + CLASS = "class" + INTERFACE = "interface" + STRUCT = "struct" + ENUM = "enum" + CONSTANT = "constant" + VARIABLE = "variable" + TYPE_ALIAS = "type_alias" + MODULE = "module" + NAMESPACE = "namespace" + PROPERTY = "property" + UNKNOWN = "unknown" + + +@dataclass(frozen=True) +class Position: + """A position in a source file.""" + line: int + column: int = 0 + byte_offset: Optional[int] = None + + def __post_init__(self) -> None: + if self.line < 0: + raise ValidationError("Line must be non-negative", field="line", value=self.line) + if self.column < 0: + raise ValidationError("Column must be non-negative", field="column", value=self.column) + + +@dataclass(frozen=True) +class Range: + """A range in a source file (start inclusive, end exclusive).""" + start: Position + end: Position + + def __post_init__(self) -> None: + if self.start.line > self.end.line: + raise ValidationError( + f"Start line ({self.start.line}) must be <= end line ({self.end.line})", + field="range", + ) + if self.start.line == self.end.line and self.start.column > self.end.column: + raise ValidationError( + f"Start column ({self.start.column}) must be <= end column ({self.end.column}) on same line", + field="range", + ) + + @property + def line_count(self) -> int: + 
return self.end.line - self.start.line + 1 + + +@dataclass(frozen=True) +class Symbol: + """A code symbol (function, class, method, etc.).""" + name: str + kind: SymbolKind + start_line: int + end_line: int + path: Optional[str] = None + signature: Optional[str] = None + docstring: Optional[str] = None + decorators: FrozenSet[str] = field(default_factory=frozenset) + parameters: FrozenSet[str] = field(default_factory=frozenset) + complexity: Optional[int] = None + + def __post_init__(self) -> None: + if not self.name: + raise ValidationError("Symbol name cannot be empty", field="name") + if self.start_line < 0: + raise ValidationError("start_line must be non-negative", field="start_line", value=self.start_line) + if self.end_line < self.start_line: + raise ValidationError( + f"end_line ({self.end_line}) must be >= start_line ({self.start_line})", + field="end_line", + ) + + @property + def line_count(self) -> int: + return self.end_line - self.start_line + 1 + + @property + def full_path(self) -> str: + return self.path or self.name + + +@dataclass(frozen=True) +class Chunk: + """A code chunk for indexing.""" + id: str + content: str + start_line: int + end_line: int + file_path: str + language: str + chunk_type: ChunkType = ChunkType.UNKNOWN + symbol: Optional[str] = None + symbol_path: Optional[str] = None + imports: FrozenSet[str] = field(default_factory=frozenset) + calls: FrozenSet[str] = field(default_factory=frozenset) + metadata: Dict[str, Any] = field(default_factory=dict, hash=False) + + def __post_init__(self) -> None: + if not self.id: + raise ValidationError("Chunk id cannot be empty", field="id") + if not self.content: + raise ValidationError("Chunk content cannot be empty", field="content") + if self.start_line < 0: + raise ValidationError("start_line must be non-negative", field="start_line", value=self.start_line) + if self.end_line < self.start_line: + raise ValidationError( + f"end_line ({self.end_line}) must be >= start_line 
({self.start_line})", + field="end_line", + ) + + @property + def line_count(self) -> int: + return self.end_line - self.start_line + 1 + + +@dataclass(frozen=True) +class ImportRef: + """An import reference.""" + module: str + names: FrozenSet[str] = field(default_factory=frozenset) + is_from_import: bool = False + alias: Optional[str] = None + line: Optional[int] = None + + +@dataclass(frozen=True) +class CallRef: + """A function/method call reference.""" + callee: str + caller_symbol: Optional[str] = None + line: Optional[int] = None + resolved_path: Optional[str] = None + + +@dataclass(frozen=True) +class FileAnalysis: + """Complete analysis result for a file.""" + file_path: str + language: str + symbols: FrozenSet[Symbol] = field(default_factory=frozenset) + chunks: FrozenSet[Chunk] = field(default_factory=frozenset) + imports: FrozenSet[ImportRef] = field(default_factory=frozenset) + calls: FrozenSet[CallRef] = field(default_factory=frozenset) + file_hash: Optional[str] = None + line_count: int = 0 + parse_time_ms: Optional[float] = None + + +@dataclass +class IndexingResult: + """Result of indexing a file or batch of files.""" + files_processed: int = 0 + files_skipped: int = 0 + files_failed: int = 0 + chunks_indexed: int = 0 + symbols_indexed: int = 0 + errors: List[str] = field(default_factory=list) + duration_seconds: float = 0.0 + + @property + def success_rate(self) -> float: + total = self.files_processed + self.files_skipped + self.files_failed + if total == 0: + return 1.0 + return self.files_processed / total + + +def chunk_from_dict(data: Dict[str, Any]) -> Chunk: + """Create a Chunk from a dictionary (for interop with existing code).""" + return Chunk( + id=data.get("id", data.get("chunk_id", "")), + content=data.get("content", data.get("code", data.get("text", ""))), + start_line=data.get("start_line", data.get("start", 0)), + end_line=data.get("end_line", data.get("end", 0)), + file_path=data.get("file_path", data.get("path", "")), + 
language=data.get("language", "unknown"), + chunk_type=ChunkType(data.get("chunk_type", data.get("type", "unknown"))), + symbol=data.get("symbol", data.get("name")), + symbol_path=data.get("symbol_path"), + imports=frozenset(data.get("imports", [])), + calls=frozenset(data.get("calls", [])), + metadata=data.get("metadata", {}), + ) + + +def symbol_from_dict(data: Dict[str, Any]) -> Symbol: + """Create a Symbol from a dictionary (for interop with existing code).""" + kind_str = data.get("kind", data.get("type", "unknown")) + try: + kind = SymbolKind(kind_str.lower()) + except ValueError: + kind = SymbolKind.UNKNOWN + + return Symbol( + name=data.get("name", ""), + kind=kind, + start_line=data.get("start_line", data.get("start", 0)), + end_line=data.get("end_line", data.get("end", 0)), + path=data.get("path", data.get("symbol_path")), + signature=data.get("signature"), + docstring=data.get("docstring"), + decorators=frozenset(data.get("decorators", [])), + parameters=frozenset(data.get("parameters", [])), + complexity=data.get("complexity"), + ) + + +__all__ = [ + "ChunkType", + "SymbolKind", + "Position", + "Range", + "Symbol", + "Chunk", + "ImportRef", + "CallRef", + "FileAnalysis", + "IndexingResult", + "chunk_from_dict", + "symbol_from_dict", +] diff --git a/scripts/ingest/pipeline.py b/scripts/ingest/pipeline.py index 21d5ef3e..405dc30b 100644 --- a/scripts/ingest/pipeline.py +++ b/scripts/ingest/pipeline.py @@ -1161,106 +1161,106 @@ def make_point( ] upsert_points(client, collection, points) - # Emit graph edges for symbol relationships + # Emit graph edges for symbol relationships (always on - Qdrant flat graph) + # Neo4j takes precedence when NEO4J_GRAPH=1 is set # Always try symbol-level edges first, fall back to file-level if no symbol_calls try: - if os.environ.get("INDEX_GRAPH_EDGES", "1").lower() in {"1", "true", "yes", "on"}: - graph_coll = ensure_graph_collection(client, collection) - # Delete old edges for this file before upserting new ones - 
delete_edges_by_path(client, graph_coll, str(file_path), repo=repo_tag) - - all_edges = [] - if symbol_calls: - # Symbol-level edges: use AST-extracted caller→callee relationships - for caller, callees in symbol_calls.items(): - if not caller or not callees: - continue - start_line = None - end_line = None - sym_info = symbol_meta_by_path.get(caller) or symbol_meta_by_name.get(caller) - if sym_info is not None: - try: - start_line = int(getattr(sym_info, "start_line", 0) or 0) - end_line = int(getattr(sym_info, "end_line", 0) or 0) - except Exception: - start_line = None - end_line = None - # Get caller_point_id from symbol_path mapping - caller_pid = symbol_path_to_point_id.get(caller) - all_edges.extend( - _extract_call_edges_compat( - symbol_path=caller, - calls=callees, - path=str(file_path), - repo=repo_tag, - start_line=start_line, - end_line=end_line, - language=language, - caller_point_id=caller_pid, - import_paths=import_map, - collection=collection, - qdrant_client=client, - ) - ) - if imports: - # For file-level imports, use first point ID if available - file_pid = next(iter(symbol_path_to_point_id.values()), None) if symbol_path_to_point_id else None - all_edges.extend( - _extract_import_edges_compat( - symbol_path=str(file_path), - imports=imports, - path=str(file_path), - repo=repo_tag, - language=language, - caller_point_id=file_pid, - collection=collection, - qdrant_client=client, - ) - ) - else: - # File-level fallback: emit file→symbol edges - source_file_path = str(file_path) - # Use first point ID for file-level edges - file_pid = next(iter(symbol_path_to_point_id.values()), None) if symbol_path_to_point_id else None - if calls: - all_edges.extend(_extract_call_edges_compat( - symbol_path=source_file_path, - calls=calls, - path=source_file_path, + graph_coll = ensure_graph_collection(client, collection) + # Delete old edges for this file before upserting new ones + delete_edges_by_path(client, graph_coll, str(file_path), repo=repo_tag) + + 
all_edges = [] + if symbol_calls: + # Symbol-level edges: use AST-extracted caller→callee relationships + for caller, callees in symbol_calls.items(): + if not caller or not callees: + continue + start_line = None + end_line = None + sym_info = symbol_meta_by_path.get(caller) or symbol_meta_by_name.get(caller) + if sym_info is not None: + try: + start_line = int(getattr(sym_info, "start_line", 0) or 0) + end_line = int(getattr(sym_info, "end_line", 0) or 0) + except Exception: + start_line = None + end_line = None + # Get caller_point_id from symbol_path mapping + caller_pid = symbol_path_to_point_id.get(caller) + all_edges.extend( + _extract_call_edges_compat( + symbol_path=caller, + calls=callees, + path=str(file_path), repo=repo_tag, - caller_point_id=file_pid, + start_line=start_line, + end_line=end_line, + language=language, + caller_point_id=caller_pid, import_paths=import_map, collection=collection, qdrant_client=client, - )) - if imports: - all_edges.extend(_extract_import_edges_compat( - symbol_path=source_file_path, + ) + ) + if imports: + # For file-level imports, use first point ID if available + file_pid = next(iter(symbol_path_to_point_id.values()), None) if symbol_path_to_point_id else None + all_edges.extend( + _extract_import_edges_compat( + symbol_path=str(file_path), imports=imports, - path=source_file_path, + path=str(file_path), repo=repo_tag, + language=language, caller_point_id=file_pid, collection=collection, qdrant_client=client, + ) + ) + else: + # File-level fallback: emit file→symbol edges + source_file_path = str(file_path) + # Use first point ID for file-level edges + file_pid = next(iter(symbol_path_to_point_id.values()), None) if symbol_path_to_point_id else None + if calls: + all_edges.extend(_extract_call_edges_compat( + symbol_path=source_file_path, + calls=calls, + path=source_file_path, + repo=repo_tag, + caller_point_id=file_pid, + import_paths=import_map, + collection=collection, + qdrant_client=client, + )) + if imports: + 
all_edges.extend(_extract_import_edges_compat( + symbol_path=source_file_path, + imports=imports, + path=source_file_path, + repo=repo_tag, + caller_point_id=file_pid, + collection=collection, + qdrant_client=client, + )) + + # Extract inheritance edges (INHERITS_FROM) for all classes + if inheritance_map: + for class_name, base_classes in inheritance_map.items(): + if class_name and base_classes: + all_edges.extend(extract_inheritance_edges( + class_name=class_name, + base_classes=base_classes, + path=str(file_path), + repo=repo_tag, + language=language, + import_paths=import_map, + collection=collection, + qdrant_client=client, )) - # Extract inheritance edges (INHERITS_FROM) for all classes - if inheritance_map: - for class_name, base_classes in inheritance_map.items(): - if class_name and base_classes: - all_edges.extend(extract_inheritance_edges( - class_name=class_name, - base_classes=base_classes, - path=str(file_path), - repo=repo_tag, - language=language, - import_paths=import_map, - collection=collection, - qdrant_client=client, - )) - - if all_edges: - upsert_edges(client, graph_coll, all_edges) + if all_edges: + upsert_edges(client, graph_coll, all_edges) except Exception as e: # Don't fail indexing if graph edges fail logger.warning(f"Failed to emit graph edges for {file_path}: {e}") @@ -1290,6 +1290,16 @@ def index_repo( schema_mode: str | None = None, ): """Index a repository into Qdrant.""" + # CRITICAL OPTIMIZATION: When recreating collection, skip all cache checks and deduplication + # The collection is empty, so: + # - skip_unchanged=False: no point checking if file hash changed (nothing in DB) + # - dedupe=False: no point deleting existing points (collection is empty) + # This avoids 2 Qdrant calls per file (scroll + delete) that are wasteful for fresh collections + if recreate: + skip_unchanged = False + dedupe = False + logger.info("[index_repo] Recreate mode: skipping cache checks and deduplication (collection is fresh)") + fast_fs = 
_env_truthy(os.environ.get("INDEX_FS_FASTPATH"), False) if skip_unchanged and not recreate and fast_fs and get_cached_file_meta is not None: try: @@ -2162,99 +2172,99 @@ def process_file_with_smart_reindexing( if all_points: _upsert_points_fn(client, current_collection, all_points) - # Emit graph edges for symbol relationships + # Emit graph edges for symbol relationships (always on - Qdrant flat graph) + # Neo4j takes precedence when NEO4J_GRAPH=1 is set # Always try symbol-level edges first, fall back to file-level if no symbol_calls try: - if os.environ.get("INDEX_GRAPH_EDGES", "1").lower() in {"1", "true", "yes", "on"}: - graph_coll = ensure_graph_collection(client, current_collection) - delete_edges_by_path(client, graph_coll, fp, repo=per_file_repo) - - all_edges = [] - if symbol_calls: - # Symbol-level edges: use AST-extracted caller→callee relationships - for caller, callees in symbol_calls.items(): - if not caller or not callees: - continue - start_line = None - sym_info = symbol_meta_by_path.get(caller) or symbol_meta_by_name.get(caller) - if sym_info is not None: - try: - start_line = int(getattr(sym_info, "start_line", 0) or 0) - except Exception: - start_line = None - # Get caller_point_id from symbol_path mapping - caller_pid = symbol_path_to_point_id_sr.get(caller) - all_edges.extend( - _extract_call_edges_compat( - symbol_path=caller, - calls=callees, - path=fp, - repo=per_file_repo, - start_line=start_line, - language=language, - caller_point_id=caller_pid, - import_paths=import_map, - ) - ) - if imports: - # For file-level imports, use first point ID if available - file_pid = next(iter(symbol_path_to_point_id_sr.values()), None) if symbol_path_to_point_id_sr else None - all_edges.extend( - _extract_import_edges_compat( - symbol_path=fp, - imports=imports, - path=fp, - repo=per_file_repo, - language=language, - caller_point_id=file_pid, - ) + graph_coll = ensure_graph_collection(client, current_collection) + delete_edges_by_path(client, graph_coll, 
fp, repo=per_file_repo) + + all_edges = [] + if symbol_calls: + # Symbol-level edges: use AST-extracted caller→callee relationships + for caller, callees in symbol_calls.items(): + if not caller or not callees: + continue + start_line = None + sym_info = symbol_meta_by_path.get(caller) or symbol_meta_by_name.get(caller) + if sym_info is not None: + try: + start_line = int(getattr(sym_info, "start_line", 0) or 0) + except Exception: + start_line = None + # Get caller_point_id from symbol_path mapping + caller_pid = symbol_path_to_point_id_sr.get(caller) + all_edges.extend( + _extract_call_edges_compat( + symbol_path=caller, + calls=callees, + path=fp, + repo=per_file_repo, + start_line=start_line, + language=language, + caller_point_id=caller_pid, + import_paths=import_map, ) - else: - # File-level fallback: emit file→symbol edges - meta0 = {} - try: - if all_points and hasattr(all_points[0], "payload"): - meta0 = all_points[0].payload.get("metadata", {}) or {} - except Exception: - meta0 = {} - file_calls = meta0.get("calls", []) or [] - file_imports = meta0.get("imports", []) or [] - file_import_map = meta0.get("import_map", {}) or {} - # Use first point ID for file-level edges + ) + if imports: + # For file-level imports, use first point ID if available file_pid = next(iter(symbol_path_to_point_id_sr.values()), None) if symbol_path_to_point_id_sr else None - if file_calls: - all_edges.extend(_extract_call_edges_compat( + all_edges.extend( + _extract_import_edges_compat( symbol_path=fp, - calls=file_calls, + imports=imports, path=fp, repo=per_file_repo, + language=language, caller_point_id=file_pid, - import_paths=file_import_map, - )) - if file_imports: - all_edges.extend(_extract_import_edges_compat( - symbol_path=fp, - imports=file_imports, + ) + ) + else: + # File-level fallback: emit file→symbol edges + meta0 = {} + try: + if all_points and hasattr(all_points[0], "payload"): + meta0 = all_points[0].payload.get("metadata", {}) or {} + except Exception: + meta0 
= {} + file_calls = meta0.get("calls", []) or [] + file_imports = meta0.get("imports", []) or [] + file_import_map = meta0.get("import_map", {}) or {} + # Use first point ID for file-level edges + file_pid = next(iter(symbol_path_to_point_id_sr.values()), None) if symbol_path_to_point_id_sr else None + if file_calls: + all_edges.extend(_extract_call_edges_compat( + symbol_path=fp, + calls=file_calls, + path=fp, + repo=per_file_repo, + caller_point_id=file_pid, + import_paths=file_import_map, + )) + if file_imports: + all_edges.extend(_extract_import_edges_compat( + symbol_path=fp, + imports=file_imports, + path=fp, + repo=per_file_repo, + caller_point_id=file_pid, + )) + + # Extract inheritance edges (INHERITS_FROM) for all classes + if inheritance_map: + for class_name, base_classes in inheritance_map.items(): + if class_name and base_classes: + all_edges.extend(extract_inheritance_edges( + class_name=class_name, + base_classes=base_classes, path=fp, repo=per_file_repo, - caller_point_id=file_pid, + language=language, + import_paths=import_map, )) - # Extract inheritance edges (INHERITS_FROM) for all classes - if inheritance_map: - for class_name, base_classes in inheritance_map.items(): - if class_name and base_classes: - all_edges.extend(extract_inheritance_edges( - class_name=class_name, - base_classes=base_classes, - path=fp, - repo=per_file_repo, - language=language, - import_paths=import_map, - )) - - if all_edges: - upsert_edges(client, graph_coll, all_edges) + if all_edges: + upsert_edges(client, graph_coll, all_edges) except Exception as e: logger.warning(f"Failed to emit graph edges for {fp}: {e}") diff --git a/scripts/ingest/qdrant.py b/scripts/ingest/qdrant.py index f3267082..7711988e 100644 --- a/scripts/ingest/qdrant.py +++ b/scripts/ingest/qdrant.py @@ -466,8 +466,8 @@ def ensure_collection( ) print(f"[COLLECTION_SUCCESS] Successfully updated collection {name} with missing vectors") except Exception as update_e: - print( - f"[COLLECTION_WARNING] 
Cannot add missing vectors to {name} ({update_e}). " + logger.debug( + f"Cannot add missing vectors to {name} ({update_e}). " "Continuing without them for this run." ) except Exception as e: diff --git a/scripts/ingest/search_chunker.py b/scripts/ingest/search_chunker.py index f56b2f91..bb287087 100644 --- a/scripts/ingest/search_chunker.py +++ b/scripts/ingest/search_chunker.py @@ -190,7 +190,7 @@ def chunk(self, content: str, language: str) -> List[ChunkResult]: return [self._content_to_result(content, 1, len(content.splitlines()))] if self.config.deduplicate: - chunks = self._deduplicate(chunks) + chunks = self._deduplicate_v2(chunks, language) chunks = self._split_oversized(chunks, content) chunks = self._merge_compatible(chunks, content) @@ -322,7 +322,7 @@ def _classify_concept(self, text: str, kind: Optional[str]) -> ConceptType: return ConceptType.BLOCK # Default def _deduplicate(self, chunks: List[SemanticChunk]) -> List[SemanticChunk]: - """Remove chunks with identical content.""" + """Remove chunks with identical content (legacy, hash-based).""" result = [] for chunk in chunks: if chunk.content_hash not in self._seen_hashes: @@ -330,6 +330,14 @@ def _deduplicate(self, chunks: List[SemanticChunk]) -> List[SemanticChunk]: result.append(chunk) return result + def _deduplicate_v2(self, chunks: List[SemanticChunk], language: str) -> List[SemanticChunk]: + """Remove chunks using O(n log n) deduplication with substring detection.""" + try: + from scripts.ingest.chunk_deduplication import deduplicate_semantic_chunks + return deduplicate_semantic_chunks(chunks, language) + except ImportError: + return self._deduplicate(chunks) + def _split_oversized(self, chunks: List[SemanticChunk], content: str) -> List[SemanticChunk]: """Split chunks that exceed size limits.""" result = [] diff --git a/scripts/ingest/tree_cache.py b/scripts/ingest/tree_cache.py new file mode 100644 index 00000000..3f834279 --- /dev/null +++ b/scripts/ingest/tree_cache.py @@ -0,0 +1,241 @@ 
+"""LRU cache for parsed syntax trees with automatic invalidation. + +Provides significant performance improvement by caching parsed ASTs +and validating freshness via file mtime/size. Thread-safe for concurrent access. + +Usage: + cache = TreeCache(max_entries=1000) + + # Try to get cached tree + tree = cache.get(file_path) + if tree is None: + tree = parser.parse(content) + cache.put(file_path, tree) + + # Get statistics + stats = cache.get_stats() +""" + +from __future__ import annotations + +import logging +import time +from collections import OrderedDict +from pathlib import Path +from threading import RLock +from typing import Any, Optional, Dict + +logger = logging.getLogger(__name__) + + +class TreeCacheEntry: + """Represents a cached syntax tree with metadata.""" + + __slots__ = ("tree", "file_path", "mtime", "size", "access_time", "hit_count") + + def __init__(self, tree: Any, file_path: Path, mtime: float, size: int): + self.tree = tree + self.file_path = file_path + self.mtime = mtime + self.size = size + self.access_time = time.time() + self.hit_count = 0 + + def is_valid(self) -> bool: + """Check if cache entry is still valid based on file mtime and size.""" + try: + stat = self.file_path.stat() + return stat.st_mtime == self.mtime and stat.st_size == self.size + except (OSError, FileNotFoundError): + return False + + def touch(self) -> None: + """Update access time and increment hit count.""" + self.access_time = time.time() + self.hit_count += 1 + + +class TreeCache: + """LRU cache for parsed syntax trees with automatic invalidation.""" + + def __init__(self, max_entries: int = 1000, max_memory_mb: int = 500): + self.max_entries = max_entries + self.max_memory_bytes = max_memory_mb * 1024 * 1024 + self._cache: OrderedDict[str, TreeCacheEntry] = OrderedDict() + self._lock = RLock() + + self._hits = 0 + self._misses = 0 + self._evictions = 0 + self._invalidations = 0 + + def get(self, file_path: Path) -> Optional[Any]: + """Get cached syntax tree for 
file, returning None if not cached or stale.""" + cache_key = str(file_path.resolve()) + + with self._lock: + if cache_key not in self._cache: + self._misses += 1 + return None + + entry = self._cache[cache_key] + + if not entry.is_valid(): + del self._cache[cache_key] + self._invalidations += 1 + self._misses += 1 + return None + + self._cache.move_to_end(cache_key) + entry.touch() + self._hits += 1 + return entry.tree + + def get_for_comparison(self, file_path: Path) -> Optional[Any]: + """Get cached tree even if stale - useful for incremental parsing comparison.""" + cache_key = str(file_path.resolve()) + + with self._lock: + if cache_key not in self._cache: + self._misses += 1 + return None + return self._cache[cache_key].tree + + def put(self, file_path: Path, tree: Any) -> None: + """Cache a parsed syntax tree.""" + if tree is None: + return + + cache_key = str(file_path.resolve()) + + try: + stat = file_path.stat() + mtime = stat.st_mtime + size = stat.st_size + except (OSError, FileNotFoundError): + return + + with self._lock: + entry = TreeCacheEntry(tree, file_path, mtime, size) + + if cache_key in self._cache: + del self._cache[cache_key] + + self._cache[cache_key] = entry + self._enforce_limits() + + def invalidate(self, file_path: Path) -> bool: + """Invalidate cached entry for a file. 
Returns True if found and removed.""" + cache_key = str(file_path.resolve()) + + with self._lock: + if cache_key in self._cache: + del self._cache[cache_key] + self._invalidations += 1 + return True + return False + + def clear(self) -> None: + """Clear all cached entries.""" + with self._lock: + self._cache.clear() + + def _enforce_limits(self) -> None: + """Enforce cache size and memory limits using LRU eviction.""" + while len(self._cache) > self.max_entries: + self._evict_lru() + + estimated_memory = sum(entry.size for entry in self._cache.values()) + while estimated_memory > self.max_memory_bytes and self._cache: + evicted_entry = self._evict_lru() + if evicted_entry: + estimated_memory -= evicted_entry.size + + def _evict_lru(self) -> Optional[TreeCacheEntry]: + """Evict least recently used entry.""" + if not self._cache: + return None + _, entry = self._cache.popitem(last=False) + self._evictions += 1 + return entry + + def get_stats(self) -> Dict[str, Any]: + """Get cache statistics.""" + with self._lock: + total_requests = self._hits + self._misses + hit_rate = (self._hits / total_requests * 100) if total_requests > 0 else 0.0 + + return { + "entries": len(self._cache), + "max_entries": self.max_entries, + "hits": self._hits, + "misses": self._misses, + "hit_rate_percent": round(hit_rate, 2), + "evictions": self._evictions, + "invalidations": self._invalidations, + "total_requests": total_requests, + "estimated_memory_mb": round( + sum(entry.size for entry in self._cache.values()) / 1024 / 1024, 2 + ), + "max_memory_mb": round(self.max_memory_bytes / 1024 / 1024, 2), + } + + def cleanup_stale_entries(self) -> int: + """Remove all stale entries. 
Returns number removed.""" + stale_keys = [] + + with self._lock: + for cache_key, entry in self._cache.items(): + if not entry.is_valid(): + stale_keys.append(cache_key) + + for key in stale_keys: + del self._cache[key] + self._invalidations += 1 + + return len(stale_keys) + + def get_cache_info(self, file_path: Path) -> Optional[Dict[str, Any]]: + """Get detailed information about a cached entry.""" + cache_key = str(file_path.resolve()) + + with self._lock: + if cache_key not in self._cache: + return None + + entry = self._cache[cache_key] + return { + "file_path": str(entry.file_path), + "cached_mtime": entry.mtime, + "cached_size": entry.size, + "access_time": entry.access_time, + "hit_count": entry.hit_count, + "is_valid": entry.is_valid(), + "age_seconds": time.time() - entry.access_time, + } + + +_default_cache: Optional[TreeCache] = None + + +def get_default_cache() -> TreeCache: + """Get or create the default global tree cache instance.""" + global _default_cache + if _default_cache is None: + _default_cache = TreeCache() + return _default_cache + + +def configure_default_cache(max_entries: int = 1000, max_memory_mb: int = 500) -> TreeCache: + """Configure the default global tree cache.""" + global _default_cache + _default_cache = TreeCache(max_entries, max_memory_mb) + return _default_cache + + +__all__ = [ + "TreeCache", + "TreeCacheEntry", + "get_default_cache", + "configure_default_cache", +] diff --git a/scripts/ingest/types.py b/scripts/ingest/types.py new file mode 100644 index 00000000..e2e6d53b --- /dev/null +++ b/scripts/ingest/types.py @@ -0,0 +1,73 @@ +"""Type aliases for semantic type safety. + +Using NewType creates distinct types that catch bugs at type-check time +while having zero runtime overhead. + +Example: + def get_chunk(chunk_id: ChunkId) -> Chunk: + ... 
+ + # Type checker catches this mistake: + file_id: FileId = FileId(123) + get_chunk(file_id) # Error: FileId is not ChunkId +""" + +from __future__ import annotations + +from typing import NewType, TypeVar, Dict, Any, List, Optional, Union + +ChunkId = NewType("ChunkId", str) +FileId = NewType("FileId", str) +PointId = NewType("PointId", str) + +LineNumber = NewType("LineNumber", int) +ByteOffset = NewType("ByteOffset", int) +ColumnNumber = NewType("ColumnNumber", int) + +TokenCount = NewType("TokenCount", int) +CharCount = NewType("CharCount", int) + +Score = NewType("Score", float) +Embedding = NewType("Embedding", List[float]) + +FilePath = NewType("FilePath", str) +RepoName = NewType("RepoName", str) +CollectionName = NewType("CollectionName", str) + +Language = NewType("Language", str) +SymbolPath = NewType("SymbolPath", str) +SymbolName = NewType("SymbolName", str) + +FileHash = NewType("FileHash", str) +ContentHash = NewType("ContentHash", str) + +Payload = Dict[str, Any] +Metadata = Dict[str, Any] + +T = TypeVar("T") +ChunkT = TypeVar("ChunkT", bound="Chunk") + +__all__ = [ + "ChunkId", + "FileId", + "PointId", + "LineNumber", + "ByteOffset", + "ColumnNumber", + "TokenCount", + "CharCount", + "Score", + "Embedding", + "FilePath", + "RepoName", + "CollectionName", + "Language", + "SymbolPath", + "SymbolName", + "FileHash", + "ContentHash", + "Payload", + "Metadata", + "T", + "ChunkT", +] diff --git a/scripts/mcp_impl/symbol_graph.py b/scripts/mcp_impl/symbol_graph.py index 8ad38a77..82436e06 100644 --- a/scripts/mcp_impl/symbol_graph.py +++ b/scripts/mcp_impl/symbol_graph.py @@ -91,11 +91,19 @@ def clear_graph_collection_cache() -> None: def _get_graph_backend(): - """Return Neo4j graph backend when enabled, otherwise None.""" + """Return graph backend (Neo4j or Qdrant). 
+ + Both backends are now supported through the unified GraphBackend interface: + - NEO4J_GRAPH=1: Uses Neo4j backend (takes precedence) + - Otherwise: Uses QdrantGraphBackend (default, always on) + + Returns None only on error, never for Qdrant-as-default case. + """ try: from scripts.graph_backends import get_graph_backend backend = get_graph_backend() - if backend.backend_type == "neo4j": + # Return any valid backend (neo4j or qdrant) + if backend is not None: return backend except Exception as e: logger.debug(f"Suppressed exception: {e} - graph backend lookup") @@ -1371,18 +1379,10 @@ async def graph_query_fn(**kwargs): results = [] used_graph = True - # Fallback for callees: use _query_callees which can use metadata.calls array - if query_type == "callees" and not results and not used_graph and not graph_backend: - results = await _query_callees( - client=client, - collection=coll, - symbol=symbol, - limit=limit, - language=language, - repo=repo, - ) # Fall back to legacy array field query if graph is unavailable or we opted to fallback on empty. - elif not results and not used_graph: + # Both Qdrant and Neo4j backends are now supported, so graph_backend should always be set. + # This fallback is for when graph returns empty or when graph backend fails to initialize. 
+ if not results and not used_graph: if query_type == "callers": # Find chunks where metadata.calls array contains the symbol (exact match) results = await _query_array_field( @@ -1407,6 +1407,16 @@ async def graph_query_fn(**kwargs): under=_norm_under(under), repo=repo, ) + elif query_type == "callees": + # Find callees using metadata.calls array lookup + results = await _query_callees( + client=client, + collection=coll, + symbol=symbol, + limit=limit, + language=language, + repo=repo, + ) elif query_type == "definition": results = await _query_definition( client=client, diff --git a/scripts/upload_service.py b/scripts/upload_service.py index de774366..3c718ffd 100644 --- a/scripts/upload_service.py +++ b/scripts/upload_service.py @@ -467,14 +467,28 @@ def validate_bundle_format(bundle_path: Path) -> Dict[str, Any]: if not any(req_file in member for member in members): raise ValueError(f"Missing required file: {req_file}") - # Extract and validate manifest + # Extract and validate manifest - look for root-level manifest.json only + # The bundle structure is {bundle_id}/manifest.json at the root manifest_member = None + manifest_candidates = [m for m in members if m.endswith("manifest.json")] + logger.debug(f"[upload_service] Bundle members: {members[:20]}...") + logger.debug(f"[upload_service] Manifest candidates: {manifest_candidates}") + + # Prefer root-level manifest (exactly one path component before manifest.json) for member in members: - if member.endswith("manifest.json"): + if member.endswith("/manifest.json") and member.count("/") == 1: manifest_member = member break + + # Fallback: if no root-level manifest, try any manifest.json (but NOT in files/ subdirs) + if not manifest_member: + for member in members: + if member.endswith("manifest.json") and "/files/" not in member: + manifest_member = member + break if not manifest_member: + logger.error(f"[upload_service] No valid manifest.json found. 
Candidates were: {manifest_candidates}") raise ValueError("manifest.json not found in bundle") manifest_file = tar.extractfile(manifest_member) @@ -482,11 +496,13 @@ def validate_bundle_format(bundle_path: Path) -> Dict[str, Any]: raise ValueError("Cannot extract manifest.json") manifest = json.loads(manifest_file.read().decode('utf-8')) + logger.debug(f"[upload_service] Parsed manifest keys: {list(manifest.keys())}") # Validate manifest structure required_fields = ["version", "bundle_id", "workspace_path", "created_at", "sequence_number"] for field in required_fields: if field not in manifest: + logger.error(f"[upload_service] Manifest missing field '{field}'. Got keys: {list(manifest.keys())}") raise ValueError(f"Missing required field in manifest: {field}") return manifest diff --git a/scripts/workspace_state.py b/scripts/workspace_state.py index aec35330..ea63506a 100644 --- a/scripts/workspace_state.py +++ b/scripts/workspace_state.py @@ -79,6 +79,7 @@ def _get_redis_client(): return _REDIS_CLIENT try: import redis # type: ignore + from redis.connection import ConnectionPool except Exception as e: logger.warning(f"Redis backend enabled but redis package not available: {e}") return None @@ -86,21 +87,26 @@ def _get_redis_client(): try: socket_timeout = float(os.environ.get("CODEBASE_STATE_REDIS_SOCKET_TIMEOUT", "2") or 2) connect_timeout = float(os.environ.get("CODEBASE_STATE_REDIS_CONNECT_TIMEOUT", "2") or 2) + max_connections = int(os.environ.get("CODEBASE_STATE_REDIS_MAX_CONNECTIONS", "10") or 10) except Exception: socket_timeout = 2.0 connect_timeout = 2.0 + max_connections = 10 try: client = redis.Redis.from_url( url, decode_responses=True, socket_timeout=socket_timeout, socket_connect_timeout=connect_timeout, + max_connections=max_connections, + retry_on_timeout=True, ) try: client.ping() except Exception as e: logger.warning(f"Redis backend enabled but ping failed: {e}") return None + logger.info(f"Redis client initialized 
(max_connections={max_connections})") _REDIS_CLIENT = client return _REDIS_CLIENT except Exception as e: @@ -108,13 +114,31 @@ def _get_redis_client(): return None +def _redis_retry(fn, retries: int = 2, delay: float = 0.1): + """Retry a Redis operation on transient failures.""" + last_err = None + for attempt in range(retries + 1): + try: + return fn() + except Exception as e: + last_err = e + err_str = str(e).lower() + # Retry on timeout/connection errors, not on logic errors + if any(x in err_str for x in ("timeout", "connection", "reset", "broken pipe")): + if attempt < retries: + time.sleep(delay * (attempt + 1)) + continue + raise + raise last_err # type: ignore + + def _redis_get_json(kind: str, path: Path) -> Optional[Dict[str, Any]]: client = _get_redis_client() if client is None: return None key = _redis_key_for_path(kind, path) try: - raw = client.get(key) + raw = _redis_retry(lambda: client.get(key)) except Exception as e: logger.debug(f"Redis get failed for {key}: {e}") return None @@ -141,7 +165,7 @@ def _redis_set_json(kind: str, path: Path, obj: Dict[str, Any]) -> bool: logger.debug(f"Failed to JSON serialize redis payload for {key}: {e}") return False try: - client.set(key, payload) + _redis_retry(lambda: client.set(key, payload)) return True except Exception as e: logger.debug(f"Redis set failed for {key}: {e}") @@ -154,7 +178,7 @@ def _redis_exists(kind: str, path: Path) -> bool: return False key = _redis_key_for_path(kind, path) try: - return bool(client.exists(key)) + return bool(_redis_retry(lambda: client.exists(key))) except Exception as e: logger.debug(f"Redis exists failed for {key}: {e}") return False @@ -179,7 +203,7 @@ def _redis_get_json_by_key(key: str) -> Optional[Dict[str, Any]]: if client is None: return None try: - raw = client.get(key) + raw = _redis_retry(lambda: client.get(key)) except Exception as e: logger.debug(f"Redis get failed for {key}: {e}") return None @@ -201,7 +225,7 @@ def _redis_delete(kind: str, path: Path) -> 
bool: return False key = _redis_key_for_path(kind, path) try: - client.delete(key) + _redis_retry(lambda: client.delete(key)) return True except Exception as e: logger.debug(f"Redis delete failed for {key}: {e}") @@ -226,19 +250,22 @@ def _redis_lock(kind: str, path: Path): wait_ms = 2000 deadline = time.time() + (wait_ms / 1000.0) acquired = False + attempts = 0 while time.time() < deadline: + attempts += 1 try: if client.set(lock_key, token, nx=True, px=ttl_ms): acquired = True break except Exception as e: - logger.debug(f"Redis lock set failed for {lock_key}: {e}") + logger.warning(f"Redis lock set failed for {lock_key}: {e}") break time.sleep(0.05) if not acquired: - logger.debug(f"Redis lock not acquired for {lock_key}, proceeding without lock") + logger.info(f"Redis lock not acquired for {lock_key} after {attempts} attempts, proceeding without lock") yield return + logger.info(f"Redis lock acquired for {lock_key} (attempts={attempts}, ttl={ttl_ms}ms)") try: yield finally: @@ -249,8 +276,9 @@ def _redis_lock(kind: str, path: Path): lock_key, token, ) + logger.debug(f"Redis lock released for {lock_key}") except Exception as e: - logger.debug(f"Redis lock release failed for {lock_key}: {e}") + logger.warning(f"Redis lock release failed for {lock_key}: {e}") def is_staging_enabled() -> bool: @@ -2254,7 +2282,7 @@ def get_indexing_config_snapshot() -> Dict[str, Any]: "index_use_enhanced_ast": _env_truthy("INDEX_USE_ENHANCED_AST", False), "mini_vec_dim": _env_int("MINI_VEC_DIM"), "lex_sparse_mode": _env_truthy("LEX_SPARSE_MODE", False), - "index_graph_edges": _env_truthy("INDEX_GRAPH_EDGES", True), + "index_graph_edges": True, # Always on - Qdrant flat graph is unconditional } diff --git a/tests/test_ast_analyzer_mappings.py b/tests/test_ast_analyzer_mappings.py new file mode 100644 index 00000000..88017ec2 --- /dev/null +++ b/tests/test_ast_analyzer_mappings.py @@ -0,0 +1,533 @@ +#!/usr/bin/env python3 +""" +Comprehensive tests for ast_analyzer language mappings 
integration. + +Tests that: +1. All 32 language mappings can be instantiated +2. ast_analyzer correctly uses mappings for symbol extraction +3. Import extraction works across languages +4. Call extraction works +5. Fallback to legacy analyzers works when needed +""" + +import pytest +from scripts.ast_analyzer import ( + get_ast_analyzer, + ASTAnalyzer, + CodeSymbol, + ConceptUnit, + ImportReference, + CallReference, + _LANGUAGE_MAPPINGS_AVAILABLE, + _TS_AVAILABLE, +) +from scripts.ingest.language_mappings import _MAPPINGS, get_mapping, ConceptType + + +# ============================================================================= +# Test: All Language Mappings Instantiate +# ============================================================================= + +class TestLanguageMappingsComplete: + """Verify all 32 language mappings can be instantiated.""" + + def test_all_mappings_instantiate(self): + """Every registered mapping class should instantiate without error.""" + failed = [] + passed = [] + + for lang, mapping_class in _MAPPINGS.items(): + try: + instance = mapping_class() + assert instance is not None + assert hasattr(instance, 'get_query_for_concept') + passed.append(lang) + except Exception as e: + failed.append((lang, str(e))) + + assert len(failed) == 0, f"Failed mappings: {failed}" + assert len(passed) == 32, f"Expected 32 mappings, got {len(passed)}" + + def test_all_mappings_have_definition_query(self): + """All mappings should provide a DEFINITION query.""" + missing = [] + for lang, mapping_class in _MAPPINGS.items(): + try: + instance = mapping_class() + query = instance.get_query_for_concept(ConceptType.DEFINITION) + if query is None: + missing.append(lang) + except Exception: + pass # Tested separately + + # Some mappings (text, markdown) may not have DEFINITION queries + assert len(missing) <= 5, f"Too many missing DEFINITION queries: {missing}" + + +# ============================================================================= +# Test: Python 
Analysis +# ============================================================================= + +class TestPythonAnalysis: + """Test Python code analysis via mappings.""" + + @pytest.fixture + def analyzer(self): + return get_ast_analyzer(reset=True) + + def test_python_function_extraction(self, analyzer): + """Extract Python functions.""" + code = ''' +def hello(name: str) -> str: + """Say hello.""" + return f"Hello {name}" + +async def async_hello(): + pass +''' + result = analyzer.analyze_file('/test.py', 'python', code) + symbols = result.get('symbols', []) + + names = [s.name for s in symbols] + assert 'hello' in names + assert 'async_hello' in names + + def test_python_class_extraction(self, analyzer): + """Extract Python classes and methods.""" + code = ''' +class MyClass: + """A test class.""" + + def __init__(self, value): + self.value = value + + def get_value(self): + return self.value +''' + result = analyzer.analyze_file('/test.py', 'python', code) + symbols = result.get('symbols', []) + + names = [s.name for s in symbols] + assert 'MyClass' in names + + kinds = {s.name: s.kind for s in symbols} + assert kinds.get('MyClass') == 'class' + + def test_python_imports(self, analyzer): + """Extract Python imports.""" + code = ''' +import os +import sys +from pathlib import Path +from typing import List, Dict +''' + result = analyzer.analyze_file('/test.py', 'python', code) + imports = result.get('imports', []) + + modules = [i.module for i in imports] + assert 'os' in modules + assert 'sys' in modules + assert 'pathlib' in modules + assert 'typing' in modules + + def test_python_calls(self, analyzer): + """Extract Python function calls.""" + code = ''' +def main(): + print("Hello") + os.path.join("a", "b") + helper() + +def helper(): + pass +''' + result = analyzer.analyze_file('/test.py', 'python', code) + calls = result.get('calls', []) + + callees = [c.callee for c in calls] + assert 'print' in callees + + +# 
============================================================================= +# Test: JavaScript/TypeScript Analysis +# ============================================================================= + +class TestJavaScriptAnalysis: + """Test JavaScript/TypeScript analysis via mappings.""" + + @pytest.fixture + def analyzer(self): + return get_ast_analyzer(reset=True) + + @pytest.mark.skipif(not _TS_AVAILABLE, reason="tree-sitter not available") + def test_javascript_functions(self, analyzer): + """Extract JavaScript functions.""" + code = ''' +function greet(name) { + console.log("Hello " + name); +} + +const arrow = () => { + return 42; +}; +''' + result = analyzer.analyze_file('/test.js', 'javascript', code) + symbols = result.get('symbols', []) + + names = [s.name for s in symbols] + assert 'greet' in names + + @pytest.mark.skipif(not _TS_AVAILABLE, reason="tree-sitter not available") + def test_typescript_imports(self, analyzer): + """Extract TypeScript imports.""" + code = ''' +import { useState, useEffect } from "react"; +import axios from "axios"; +import * as fs from "fs"; +''' + result = analyzer.analyze_file('/test.ts', 'typescript', code) + imports = result.get('imports', []) + + modules = [i.module for i in imports] + assert 'react' in modules + assert 'axios' in modules + assert 'fs' in modules + + +# ============================================================================= +# Test: Go Analysis +# ============================================================================= + +class TestGoAnalysis: + """Test Go analysis via mappings.""" + + @pytest.fixture + def analyzer(self): + return get_ast_analyzer(reset=True) + + @pytest.mark.skipif(not _TS_AVAILABLE, reason="tree-sitter not available") + def test_go_functions(self, analyzer): + """Extract Go functions.""" + code = ''' +package main + +func main() { + fmt.Println("Hello") +} + +func helper(x int) int { + return x * 2 +} +''' + result = analyzer.analyze_file('/test.go', 'go', code) + symbols = 
result.get('symbols', []) + + names = [s.name for s in symbols] + assert 'main' in names + assert 'helper' in names + + @pytest.mark.skipif(not _TS_AVAILABLE, reason="tree-sitter not available") + def test_go_imports(self, analyzer): + """Extract Go imports.""" + code = ''' +package main + +import ( + "fmt" + "os" + "strings" +) + +func main() {} +''' + result = analyzer.analyze_file('/test.go', 'go', code) + imports = result.get('imports', []) + + modules = [i.module for i in imports] + assert 'fmt' in modules + assert 'os' in modules + assert 'strings' in modules + + +# ============================================================================= +# Test: Rust Analysis +# ============================================================================= + +class TestRustAnalysis: + """Test Rust analysis via mappings.""" + + @pytest.fixture + def analyzer(self): + return get_ast_analyzer(reset=True) + + @pytest.mark.skipif(not _TS_AVAILABLE, reason="tree-sitter not available") + def test_rust_functions(self, analyzer): + """Extract Rust functions.""" + code = ''' +fn main() { + println!("Hello"); +} + +pub fn helper(x: i32) -> i32 { + x * 2 +} +''' + result = analyzer.analyze_file('/test.rs', 'rust', code) + symbols = result.get('symbols', []) + + names = [s.name for s in symbols] + assert 'main' in names + assert 'helper' in names + + @pytest.mark.skipif(not _TS_AVAILABLE, reason="tree-sitter not available") + def test_rust_imports(self, analyzer): + """Extract Rust use statements.""" + code = ''' +use std::io; +use std::collections::HashMap; + +fn main() {} +''' + result = analyzer.analyze_file('/test.rs', 'rust', code) + imports = result.get('imports', []) + + modules = [i.module for i in imports] + assert any('std' in m for m in modules) + + +# ============================================================================= +# Test: Java Analysis +# ============================================================================= + +class TestJavaAnalysis: + """Test Java 
analysis via mappings.""" + + @pytest.fixture + def analyzer(self): + return get_ast_analyzer(reset=True) + + @pytest.mark.skipif(not _TS_AVAILABLE, reason="tree-sitter not available") + def test_java_class(self, analyzer): + """Extract Java class and methods.""" + code = ''' +public class Hello { + public static void main(String[] args) { + System.out.println("Hello"); + } + + private int helper(int x) { + return x * 2; + } +} +''' + result = analyzer.analyze_file('/Hello.java', 'java', code) + symbols = result.get('symbols', []) + + names = [s.name for s in symbols] + assert 'Hello' in names + assert 'main' in names + + @pytest.mark.skipif(not _TS_AVAILABLE, reason="tree-sitter not available") + def test_java_imports(self, analyzer): + """Extract Java imports.""" + code = ''' +import java.util.List; +import java.util.ArrayList; +import java.io.*; + +public class Test {} +''' + result = analyzer.analyze_file('/Test.java', 'java', code) + imports = result.get('imports', []) + + modules = [i.module for i in imports] + assert 'java.util.List' in modules + assert 'java.util.ArrayList' in modules + + +# ============================================================================= +# Test: C/C++ Analysis +# ============================================================================= + +class TestCppAnalysis: + """Test C/C++ analysis via mappings.""" + + @pytest.fixture + def analyzer(self): + return get_ast_analyzer(reset=True) + + @pytest.mark.skipif(not _TS_AVAILABLE, reason="tree-sitter not available") + def test_cpp_functions(self, analyzer): + """Extract C++ functions.""" + code = ''' +#include + +int main() { + std::cout << "Hello" << std::endl; + return 0; +} + +int helper(int x) { + return x * 2; +} +''' + result = analyzer.analyze_file('/test.cpp', 'cpp', code) + symbols = result.get('symbols', []) + + names = [s.name for s in symbols] + assert 'main' in names + + @pytest.mark.skipif(not _TS_AVAILABLE, reason="tree-sitter not available") + def 
test_cpp_includes(self, analyzer): + """Extract C++ includes.""" + code = ''' +#include +#include +#include "myheader.h" + +int main() { return 0; } +''' + result = analyzer.analyze_file('/test.cpp', 'cpp', code) + imports = result.get('imports', []) + + modules = [i.module for i in imports] + assert 'iostream' in modules + assert 'vector' in modules + + +# ============================================================================= +# Test: Multi-Language Consistency +# ============================================================================= + +class TestMultiLanguageConsistency: + """Test that analysis is consistent across languages.""" + + @pytest.fixture + def analyzer(self): + return get_ast_analyzer(reset=True) + + @pytest.mark.skipif(not _TS_AVAILABLE, reason="tree-sitter not available") + def test_all_return_correct_types(self, analyzer): + """All analyses should return correct types.""" + test_cases = [ + ('python', 'def foo(): pass'), + ('javascript', 'function foo() {}'), + ('go', 'package main\nfunc foo() {}'), + ('rust', 'fn foo() {}'), + ('java', 'public class Foo {}'), + ('cpp', 'int foo() { return 0; }'), + ] + + for lang, code in test_cases: + result = analyzer.analyze_file(f'/test.{lang}', lang, code) + + assert isinstance(result, dict), f"{lang}: result should be dict" + assert 'symbols' in result, f"{lang}: should have symbols" + assert 'imports' in result, f"{lang}: should have imports" + assert 'calls' in result, f"{lang}: should have calls" + + for sym in result.get('symbols', []): + assert isinstance(sym, CodeSymbol), f"{lang}: symbols should be CodeSymbol" + for imp in result.get('imports', []): + assert isinstance(imp, ImportReference), f"{lang}: imports should be ImportReference" + for call in result.get('calls', []): + assert isinstance(call, CallReference), f"{lang}: calls should be CallReference" + + @pytest.mark.skipif(not _TS_AVAILABLE, reason="tree-sitter not available") + def test_empty_file_handling(self, analyzer): + 
"""Empty files should not crash.""" + for lang in ['python', 'javascript', 'go', 'rust', 'java']: + result = analyzer.analyze_file(f'/empty.{lang}', lang, '') + assert isinstance(result, dict) + + result = analyzer.analyze_file(f'/whitespace.{lang}', lang, ' \n\n ') + assert isinstance(result, dict) + + +# ============================================================================= +# Test: Fallback Behavior +# ============================================================================= + +class TestFallbackBehavior: + """Test fallback to legacy analyzers.""" + + @pytest.fixture + def analyzer(self): + return get_ast_analyzer(reset=True) + + def test_unsupported_language_fallback(self, analyzer): + """Unsupported languages should fall back gracefully.""" + code = 'some unknown code here' + result = analyzer.analyze_file('/test.xyz', 'unknown_language', code) + + # Should return empty analysis, not crash + assert isinstance(result, dict) + assert 'symbols' in result + assert 'imports' in result + + def test_syntax_error_handling(self, analyzer): + """Syntax errors should be handled gracefully.""" + # Malformed Python + code = 'def foo(\n broken syntax here' + result = analyzer.analyze_file('/test.py', 'python', code) + + # Should not crash + assert isinstance(result, dict) + + +# ============================================================================= +# Test: Symbol Metadata +# ============================================================================= + +class TestSymbolMetadata: + """Test that symbol metadata is extracted correctly.""" + + @pytest.fixture + def analyzer(self): + return get_ast_analyzer(reset=True) + + def test_python_symbol_metadata(self, analyzer): + """Python symbols should have rich metadata.""" + code = ''' +@decorator +def my_function(a: int, b: str) -> bool: + """This is the docstring.""" + return True +''' + result = analyzer.analyze_file('/test.py', 'python', code) + symbols = result.get('symbols', []) + + func = next((s for s in 
symbols if s.name == 'my_function'), None) + assert func is not None + assert func.kind == 'function' + assert func.start_line > 0 + assert func.end_line >= func.start_line + + def test_symbol_line_numbers(self, analyzer): + """Symbol line numbers should be accurate.""" + code = '''# Line 1 +# Line 2 +def foo(): # Line 3 + pass # Line 4 +# Line 5 +def bar(): # Line 6 + pass # Line 7 +''' + result = analyzer.analyze_file('/test.py', 'python', code) + symbols = result.get('symbols', []) + + foo = next((s for s in symbols if s.name == 'foo'), None) + bar = next((s for s in symbols if s.name == 'bar'), None) + + assert foo is not None + assert bar is not None + assert foo.start_line == 3 + assert bar.start_line == 6 + + +# ============================================================================= +# Run tests +# ============================================================================= + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/tests/test_cast_chunker.py b/tests/test_cast_chunker.py new file mode 100644 index 00000000..cfebccb2 --- /dev/null +++ b/tests/test_cast_chunker.py @@ -0,0 +1,182 @@ +"""Tests for scripts/ingest/cast_chunker.py - CAST+ Hybrid Chunker.""" + +import pytest +from scripts.ingest.cast_chunker import ( + CASTPlusConfig, + CASTPlusChunker, + ConceptType, + SemanticChunk, + ChunkResult, + COMPATIBLE_PAIRS, + chunk_cast_plus, + get_cast_chunker, +) + + +class TestCASTPlusConfig: + """Tests for CASTPlusConfig dataclass.""" + + def test_default_values(self): + """Test default configuration values.""" + config = CASTPlusConfig() + assert config.max_chunk_size == 1200 + assert config.min_chunk_size == 50 + assert config.safe_token_limit == 6000 + assert config.merge_threshold == 0.8 + assert config.deduplicate is True + + def test_custom_values(self): + """Test custom configuration values.""" + config = CASTPlusConfig( + max_chunk_size=2000, + min_chunk_size=100, + deduplicate=False, + ) + assert config.max_chunk_size 
== 2000 + assert config.min_chunk_size == 100 + assert config.deduplicate is False + + +class TestConceptType: + """Tests for ConceptType enum.""" + + def test_concept_values(self): + """Test concept type values.""" + assert ConceptType.DEFINITION.value == "definition" + assert ConceptType.BLOCK.value == "block" + assert ConceptType.COMMENT.value == "comment" + assert ConceptType.IMPORT.value == "import" + assert ConceptType.STRUCTURE.value == "structure" + + +class TestCompatiblePairs: + """Tests for compatible concept pairs.""" + + def test_comment_definition_compatible(self): + """Test that COMMENT and DEFINITION are compatible.""" + assert (ConceptType.COMMENT, ConceptType.DEFINITION) in COMPATIBLE_PAIRS + assert (ConceptType.DEFINITION, ConceptType.COMMENT) in COMPATIBLE_PAIRS + + def test_block_definition_not_compatible(self): + """Test that BLOCK and DEFINITION are NOT compatible.""" + assert (ConceptType.BLOCK, ConceptType.DEFINITION) not in COMPATIBLE_PAIRS + + +class TestSemanticChunk: + """Tests for SemanticChunk dataclass.""" + + def test_post_init_computes_metrics(self): + """Test that __post_init__ computes metrics.""" + chunk = SemanticChunk( + concept=ConceptType.DEFINITION, + name="foo", + content="def foo(): pass", + start_line=1, + end_line=1, + ) + assert chunk.non_whitespace_chars > 0 + assert chunk.estimated_tokens > 0 + assert 0.0 <= chunk.density_score <= 1.0 + + def test_empty_content_density(self): + """Test density calculation with empty content.""" + chunk = SemanticChunk( + concept=ConceptType.DEFINITION, + name="empty", + content="", + start_line=1, + end_line=1, + ) + assert chunk.density_score == 0.0 + + +class TestCASTPlusChunker: + """Tests for CASTPlusChunker class.""" + + def test_initialization(self): + """Test chunker initialization.""" + chunker = CASTPlusChunker() + assert chunker.config is not None + assert isinstance(chunker.config, CASTPlusConfig) + + def test_custom_config(self): + """Test chunker with custom config.""" + 
config = CASTPlusConfig(max_chunk_size=500) + chunker = CASTPlusChunker(config) + assert chunker.config.max_chunk_size == 500 + + def test_chunk_simple_function(self): + """Test chunking a simple function.""" + chunker = CASTPlusChunker() + content = '''def hello(): + """Say hello.""" + print("Hello, World!") +''' + results = chunker.chunk(content, "python") + assert len(results) >= 1 + assert all(isinstance(r, ChunkResult) for r in results) + + def test_chunk_to_dicts(self): + """Test chunk_to_dicts returns dictionaries.""" + chunker = CASTPlusChunker() + content = "def foo(): pass" + results = chunker.chunk_to_dicts(content, "python") + assert all(isinstance(r, dict) for r in results) + if results: + assert "text" in results[0] + # Uses 'start' and 'end' keys, not 'start_line' + assert "start" in results[0] or "start_line" in results[0] + + def test_deduplication_enabled(self): + """Test that deduplication removes duplicates.""" + config = CASTPlusConfig(deduplicate=True) + chunker = CASTPlusChunker(config) + # Content with duplicate blocks + content = '''x = 1 +x = 1 +''' + results = chunker.chunk(content, "python") + # Should have fewer chunks due to dedup + assert len(results) >= 1 + + def test_deduplication_disabled(self): + """Test that deduplication can be disabled.""" + config = CASTPlusConfig(deduplicate=False) + chunker = CASTPlusChunker(config) + content = "x = 1" + results = chunker.chunk(content, "python") + assert len(results) >= 1 + + +class TestChunkCastPlus: + """Tests for chunk_cast_plus convenience function.""" + + def test_basic_usage(self): + """Test basic usage of chunk_cast_plus.""" + content = "def foo(): pass" + results = chunk_cast_plus(content, "python") + assert isinstance(results, list) + assert all(isinstance(r, dict) for r in results) + + def test_with_custom_config(self): + """Test with custom config.""" + config = CASTPlusConfig(max_chunk_size=500) + content = "def foo(): pass" + results = chunk_cast_plus(content, "python", 
config=config) + assert isinstance(results, list) + + +class TestGetCastChunker: + """Tests for get_cast_chunker factory function.""" + + def test_returns_chunker(self): + """Test that get_cast_chunker returns a chunker.""" + chunker = get_cast_chunker() + assert isinstance(chunker, CASTPlusChunker) + + def test_with_custom_config(self): + """Test with custom config returns new instance.""" + config = CASTPlusConfig(max_chunk_size=999) + chunker = get_cast_chunker(config) + assert chunker.config.max_chunk_size == 999 + diff --git a/tests/test_chunk_deduplication_core.py b/tests/test_chunk_deduplication_core.py new file mode 100644 index 00000000..7fd794cb --- /dev/null +++ b/tests/test_chunk_deduplication_core.py @@ -0,0 +1,197 @@ +"""Tests for scripts/ingest/chunk_deduplication.py - O(n log n) deduplication.""" + +import pytest +from scripts.ingest.chunk_deduplication import ( + normalize_content, + get_chunk_specificity, + deduplicate_chunks, + deduplicate_semantic_chunks, + TYPE_WEIGHTS, +) + + +class TestNormalizeContent: + """Tests for content normalization.""" + + def test_strips_whitespace(self): + """Test that leading/trailing whitespace is stripped.""" + assert normalize_content(" hello ") == "hello" + assert normalize_content("\n\nhello\n\n") == "hello" + + def test_normalizes_line_endings(self): + """Test that different line endings are normalized.""" + assert normalize_content("a\r\nb") == "a\nb" + assert normalize_content("a\rb") == "a\nb" + assert normalize_content("a\r\n\rb") == "a\n\nb" + + def test_empty_string(self): + """Test empty string handling.""" + assert normalize_content("") == "" + assert normalize_content(" ") == "" + + +class TestGetChunkSpecificity: + """Tests for chunk specificity ranking.""" + + def test_function_has_high_specificity(self): + """Test that function chunks have high specificity.""" + chunk = {"chunk_type": "function"} + assert get_chunk_specificity(chunk) == 4 + + def test_block_has_low_specificity(self): + """Test 
that block chunks have low specificity.""" + chunk = {"chunk_type": "block"} + assert get_chunk_specificity(chunk) == 1 + + def test_definition_concept_type(self): + """Test DEFINITION concept type (from CAST+).""" + chunk = {"chunk_type": "DEFINITION"} + assert get_chunk_specificity(chunk) == 4 + + def test_unknown_type_returns_zero(self): + """Test unknown type returns 0 (lowest specificity).""" + chunk = {"chunk_type": "unknown_type"} + assert get_chunk_specificity(chunk) == 0 + + def test_concept_key_fallback(self): + """Test fallback to 'concept' key.""" + chunk = {"concept": "function"} + assert get_chunk_specificity(chunk) == 4 + + def test_type_key_fallback(self): + """Test fallback to 'type' key.""" + chunk = {"type": "class"} + assert get_chunk_specificity(chunk) == 4 + + def test_enum_value_handling(self): + """Test handling of enum-like objects with .value.""" + from enum import Enum + class MockConcept(Enum): + DEFINITION = "definition" + chunk = {"chunk_type": MockConcept.DEFINITION} + assert get_chunk_specificity(chunk) == 4 + + +class TestDeduplicateChunks: + """Tests for deduplicate_chunks function.""" + + def test_empty_input(self): + """Test empty input returns empty list.""" + assert deduplicate_chunks([]) == [] + + def test_no_duplicates(self): + """Test chunks without duplicates are preserved.""" + chunks = [ + {"code": "def foo(): pass", "chunk_type": "function"}, + {"code": "def bar(): pass", "chunk_type": "function"}, + ] + result = deduplicate_chunks(chunks) + assert len(result) == 2 + + def test_exact_duplicates_removed(self): + """Test exact duplicate content is removed.""" + chunks = [ + {"code": "def foo(): pass", "chunk_type": "function"}, + {"code": "def foo(): pass", "chunk_type": "function"}, + ] + result = deduplicate_chunks(chunks) + assert len(result) == 1 + + def test_keeps_higher_specificity(self): + """Test that higher specificity chunk is kept on duplicate.""" + chunks = [ + {"code": "x = 1", "chunk_type": "block"}, # 
specificity 1 + {"code": "x = 1", "chunk_type": "function"}, # specificity 4 + ] + result = deduplicate_chunks(chunks) + assert len(result) == 1 + assert result[0]["chunk_type"] == "function" + + def test_vue_language_exemption(self): + """Test Vue language is exempt from deduplication.""" + chunks = [ + {"code": "same content", "chunk_type": "block"}, + {"code": "same content", "chunk_type": "block"}, + ] + result = deduplicate_chunks(chunks, language="vue") + assert len(result) == 2 + + def test_haskell_language_exemption(self): + """Test Haskell language is exempt from deduplication.""" + chunks = [ + {"code": "same content", "chunk_type": "block"}, + {"code": "same content", "chunk_type": "block"}, + ] + result = deduplicate_chunks(chunks, language="haskell") + assert len(result) == 2 + + def test_substring_removal(self): + """Test that block substrings of definitions are removed.""" + chunks = [ + { + "code": "def foo():\n x = 1\n return x", + "chunk_type": "function", + "start_line": 1, + "end_line": 3, + }, + { + "code": "x = 1", + "chunk_type": "block", + "start_line": 2, + "end_line": 2, + }, + ] + result = deduplicate_chunks(chunks) + # Block should be removed as it's a substring of the function + assert len(result) == 1 + assert result[0]["chunk_type"] == "function" + + def test_custom_content_key(self): + """Test custom content key.""" + chunks = [ + {"text": "same", "chunk_type": "block"}, + {"text": "same", "chunk_type": "block"}, + ] + result = deduplicate_chunks(chunks, content_key="text") + assert len(result) == 1 + + def test_whitespace_normalization_in_dedup(self): + """Test that whitespace differences don't prevent dedup.""" + chunks = [ + {"code": "def foo(): pass", "chunk_type": "function"}, + {"code": "def foo(): pass ", "chunk_type": "function"}, # trailing space + ] + result = deduplicate_chunks(chunks) + assert len(result) == 1 + + +class TestDeduplicateSemanticChunks: + """Tests for deduplicate_semantic_chunks function.""" + + def 
test_empty_input(self): + """Test empty input returns empty list.""" + assert deduplicate_semantic_chunks([]) == [] + + def test_preserves_original_objects(self): + """Test that original objects are returned, not copies.""" + from dataclasses import dataclass + from enum import Enum + + class ConceptType(Enum): + DEFINITION = "definition" + + @dataclass + class MockChunk: + content: str + start_line: int + end_line: int + concept: ConceptType + + chunk1 = MockChunk("def foo(): pass", 1, 1, ConceptType.DEFINITION) + chunk2 = MockChunk("def bar(): pass", 2, 2, ConceptType.DEFINITION) + + result = deduplicate_semantic_chunks([chunk1, chunk2]) + assert len(result) == 2 + assert chunk1 in result + assert chunk2 in result + diff --git a/tests/test_elbow_detection.py b/tests/test_elbow_detection.py new file mode 100644 index 00000000..2296853b --- /dev/null +++ b/tests/test_elbow_detection.py @@ -0,0 +1,170 @@ +"""Tests for scripts/hybrid/elbow_detection.py - Kneedle algorithm and adaptive thresholds.""" + +import pytest +from scripts.hybrid.elbow_detection import ( + find_elbow_kneedle, + compute_elbow_threshold, + filter_by_elbow, +) + + +class TestFindElbowKneedle: + """Tests for the Kneedle algorithm implementation.""" + + def test_clear_elbow_detected(self): + """Test detection of a clear elbow point.""" + # Clear drop after index 2 + scores = [0.95, 0.92, 0.88, 0.45, 0.42, 0.40] + elbow_idx = find_elbow_kneedle(scores) + assert elbow_idx is not None + # Elbow should be around the drop point + assert 1 <= elbow_idx <= 3 + + def test_too_few_points_returns_none(self): + """Test that fewer than 3 points returns None.""" + assert find_elbow_kneedle([0.9]) is None + assert find_elbow_kneedle([0.9, 0.8]) is None + assert find_elbow_kneedle([]) is None + + def test_identical_scores_returns_none(self): + """Test that identical scores return None (no elbow).""" + scores = [0.5, 0.5, 0.5, 0.5, 0.5] + assert find_elbow_kneedle(scores) is None + + def 
test_linear_decrease_minimal_elbow(self): + """Test linear decrease - may or may not detect elbow.""" + scores = [1.0, 0.8, 0.6, 0.4, 0.2] + # Linear decrease has no clear elbow + result = find_elbow_kneedle(scores) + # Should return None or a middle index + assert result is None or 0 <= result < len(scores) + + def test_sharp_drop_at_end(self): + """Test sharp drop at the end of the curve.""" + scores = [0.95, 0.94, 0.93, 0.92, 0.10] + elbow_idx = find_elbow_kneedle(scores) + assert elbow_idx is not None + # Elbow should be near the drop + assert elbow_idx >= 2 + + def test_gradual_then_sharp_drop(self): + """Test gradual decrease followed by sharp drop.""" + scores = [0.99, 0.98, 0.97, 0.96, 0.30, 0.29, 0.28] + elbow_idx = find_elbow_kneedle(scores) + assert elbow_idx is not None + # Elbow should be around index 3-4 + assert 2 <= elbow_idx <= 5 + + +class TestComputeElbowThreshold: + """Tests for compute_elbow_threshold function.""" + + def test_empty_input_returns_default(self): + """Test empty input returns default threshold.""" + assert compute_elbow_threshold([]) == 0.5 + # Single dict with no score extracts 0.0, which is a valid score + result = compute_elbow_threshold([{}]) + assert 0.0 <= result <= 0.5 + + def test_with_raw_scores(self): + """Test with raw float scores.""" + scores = [0.95, 0.88, 0.45, 0.42] + threshold = compute_elbow_threshold(scores) + assert 0.0 <= threshold <= 1.0 + # Threshold should be around the elbow + assert threshold >= 0.40 + + def test_with_dict_chunks(self): + """Test with dict chunks containing score key.""" + chunks = [ + {"score": 0.95}, + {"score": 0.88}, + {"score": 0.45}, + {"score": 0.42}, + ] + threshold = compute_elbow_threshold(chunks) + assert 0.0 <= threshold <= 1.0 + + def test_with_fallback_score_key(self): + """Test fallback to rerank_score when score is missing.""" + chunks = [ + {"rerank_score": 0.95}, + {"rerank_score": 0.45}, + {"rerank_score": 0.20}, + ] + threshold = compute_elbow_threshold(chunks, 
score_key="score") + assert 0.0 <= threshold <= 1.0 + + def test_zero_scores_handled_correctly(self): + """Test that 0.0 scores are handled correctly (not treated as missing).""" + chunks = [ + {"score": 0.95}, + {"score": 0.0}, # Real zero score + {"score": 0.0}, + ] + threshold = compute_elbow_threshold(chunks) + # Should not crash and should return valid threshold + assert 0.0 <= threshold <= 1.0 + + def test_custom_score_key(self): + """Test with custom score key.""" + chunks = [ + {"my_score": 0.9}, + {"my_score": 0.5}, + {"my_score": 0.1}, + ] + threshold = compute_elbow_threshold(chunks, score_key="my_score") + assert 0.0 <= threshold <= 1.0 + + +class TestFilterByElbow: + """Tests for filter_by_elbow function.""" + + def test_empty_results(self): + """Test empty input returns empty list.""" + assert filter_by_elbow([]) == [] + + def test_filters_below_threshold(self): + """Test that results below threshold are filtered.""" + results = [ + {"id": 1, "score": 0.95}, + {"id": 2, "score": 0.90}, + {"id": 3, "score": 0.30}, # Below elbow + {"id": 4, "score": 0.25}, # Below elbow + ] + filtered = filter_by_elbow(results) + # Should keep high-scoring results + assert len(filtered) >= 1 + assert all(r["score"] >= 0.25 for r in filtered) + + def test_min_results_guaranteed(self): + """Test that min_results are always returned.""" + results = [ + {"id": 1, "score": 0.95}, + {"id": 2, "score": 0.10}, + {"id": 3, "score": 0.05}, + ] + filtered = filter_by_elbow(results, min_results=2) + assert len(filtered) >= 2 + + def test_zero_score_not_treated_as_missing(self): + """Test that 0.0 score is not treated as missing.""" + results = [ + {"id": 1, "score": 0.9}, + {"id": 2, "score": 0.0}, # Real zero, not missing + {"id": 3, "score": 0.0}, + ] + # Should not crash + filtered = filter_by_elbow(results) + assert isinstance(filtered, list) + + def test_fallback_score_key_used(self): + """Test that fallback score key is used when primary is missing.""" + results = [ + {"id": 
1, "rerank_score": 0.95}, + {"id": 2, "rerank_score": 0.50}, + {"id": 3, "rerank_score": 0.10}, + ] + filtered = filter_by_elbow(results, score_key="score", fallback_score_key="rerank_score") + assert len(filtered) >= 1 + diff --git a/tests/test_termination.py b/tests/test_termination.py new file mode 100644 index 00000000..7e3c204b --- /dev/null +++ b/tests/test_termination.py @@ -0,0 +1,201 @@ +"""Tests for scripts/hybrid/termination.py - Smart termination conditions.""" + +import time +import pytest +from scripts.hybrid.termination import TerminationConfig, TerminationChecker + + +class TestTerminationConfig: + """Tests for TerminationConfig dataclass.""" + + def test_default_values(self): + """Test default configuration values.""" + config = TerminationConfig() + assert config.time_limit == 5.0 + assert config.result_limit == 500 + assert config.min_candidates_for_expansion == 5 + assert config.fixed_degradation_threshold == 0.15 + assert config.min_relevance_score == 0.3 + assert config.top_n_to_track == 5 + assert config.use_page_hinkley is True + assert config.page_hinkley_threshold == 0.5 + + def test_custom_values(self): + """Test custom configuration values.""" + config = TerminationConfig( + time_limit=10.0, + result_limit=1000, + min_candidates_for_expansion=10, + ) + assert config.time_limit == 10.0 + assert config.result_limit == 1000 + assert config.min_candidates_for_expansion == 10 + + +class TestTerminationChecker: + """Tests for TerminationChecker class.""" + + def test_initialization(self): + """Test checker initialization.""" + checker = TerminationChecker() + assert checker.iteration == 0 + assert checker.tracked_chunk_scores == {} + assert checker.elapsed() >= 0 + + def test_reset(self): + """Test reset clears state.""" + checker = TerminationChecker() + checker.iteration = 5 + checker.tracked_chunk_scores = {"a": 0.9} + checker.reset() + assert checker.iteration == 0 + assert checker.tracked_chunk_scores == {} + + def 
test_time_limit_termination(self): + """Test termination on time limit.""" + config = TerminationConfig(time_limit=0.01) # 10ms + checker = TerminationChecker(config) + + # Wait for time limit + time.sleep(0.02) + + results = [{"chunk_id": "a", "score": 0.9} for _ in range(10)] + should_terminate, reason = checker.check(results) + + assert should_terminate is True + assert reason == "time_limit" + + def test_result_limit_termination(self): + """Test termination on result limit.""" + config = TerminationConfig(result_limit=5) + checker = TerminationChecker(config) + + results = [{"chunk_id": f"c{i}", "score": 0.9} for i in range(10)] + should_terminate, reason = checker.check(results) + + assert should_terminate is True + assert reason == "result_limit" + + def test_insufficient_candidates_termination(self): + """Test termination when not enough high-scoring candidates.""" + config = TerminationConfig(min_candidates_for_expansion=5) + checker = TerminationChecker(config) + + # Only 3 results with positive scores + results = [ + {"chunk_id": "a", "score": 0.9}, + {"chunk_id": "b", "score": 0.8}, + {"chunk_id": "c", "score": 0.7}, + ] + should_terminate, reason = checker.check(results) + + assert should_terminate is True + assert reason == "insufficient_candidates" + + def test_score_degradation_termination(self): + """Test termination on score degradation via Page-Hinkley.""" + config = TerminationConfig( + fixed_degradation_threshold=0.1, + top_n_to_track=3, + min_candidates_for_expansion=1, + min_relevance_score=0.0, + use_page_hinkley=True, + page_hinkley_threshold=0.3, + min_iterations_before_stop=2, + ) + checker = TerminationChecker(config) + + # First iteration - establish baseline + results1 = [ + {"chunk_id": "a", "score": 0.9}, + {"chunk_id": "b", "score": 0.8}, + {"chunk_id": "c", "score": 0.7}, + ] + should_terminate, reason = checker.check(results1) + assert should_terminate is False + + # Second iteration - scores start dropping + results2 = [ + 
{"chunk_id": "a", "score": 0.7}, + {"chunk_id": "b", "score": 0.6}, + {"chunk_id": "c", "score": 0.5}, + ] + should_terminate, reason = checker.check(results2) + assert should_terminate is False + + # Third iteration - continued drop triggers Page-Hinkley + results3 = [ + {"chunk_id": "a", "score": 0.4}, + {"chunk_id": "b", "score": 0.3}, + {"chunk_id": "c", "score": 0.2}, + ] + should_terminate, reason = checker.check(results3) + + assert should_terminate is True + assert reason in ("score_drift_detected", "score_degradation") + + def test_min_relevance_termination(self): + """Test termination when min relevance score is too low.""" + config = TerminationConfig( + min_relevance_score=0.5, + top_n_to_track=3, + min_candidates_for_expansion=1, + ) + checker = TerminationChecker(config) + + # Results with low minimum score in top-N + results = [ + {"chunk_id": "a", "score": 0.9}, + {"chunk_id": "b", "score": 0.6}, + {"chunk_id": "c", "score": 0.3}, # Below min_relevance_score + ] + should_terminate, reason = checker.check(results) + + assert should_terminate is True + assert reason == "min_relevance" + + def test_no_termination_when_conditions_not_met(self): + """Test that checker continues when no conditions are met.""" + config = TerminationConfig( + time_limit=60.0, + result_limit=1000, + min_candidates_for_expansion=3, + min_relevance_score=0.3, + ) + checker = TerminationChecker(config) + + results = [ + {"chunk_id": "a", "score": 0.9}, + {"chunk_id": "b", "score": 0.8}, + {"chunk_id": "c", "score": 0.7}, + {"chunk_id": "d", "score": 0.6}, + {"chunk_id": "e", "score": 0.5}, + ] + should_terminate, reason = checker.check(results) + + assert should_terminate is False + assert reason == "" + + def test_get_stats(self): + """Test get_stats returns correct information.""" + checker = TerminationChecker() + results = [{"chunk_id": "a", "score": 0.9} for _ in range(10)] + checker.check(results) + checker.check(results) + + stats = checker.get_stats() + assert 
stats["iterations"] == 2 + assert "elapsed_seconds" in stats + assert stats["elapsed_seconds"] >= 0 + + def test_iteration_counter_increments(self): + """Test that iteration counter increments on each check.""" + checker = TerminationChecker() + results = [{"chunk_id": f"c{i}", "score": 0.9} for i in range(10)] + + checker.check(results) + assert checker.iteration == 1 + + checker.check(results) + assert checker.iteration == 2 + diff --git a/tests/test_workspace_state.py b/tests/test_workspace_state.py index 6d812726..d21991c1 100644 --- a/tests/test_workspace_state.py +++ b/tests/test_workspace_state.py @@ -442,29 +442,34 @@ class TestConfigDrift: """Tests for indexing config drift detection.""" def test_get_indexing_config_snapshot_includes_graph_edges(self, ws_module, monkeypatch): - """Verify index_graph_edges key exists in snapshot with default value True.""" - # Clear any existing env vars to test defaults - monkeypatch.delenv("INDEX_GRAPH_EDGES", raising=False) + """Verify index_graph_edges key exists in snapshot and is always True. + Symbol graph (Qdrant flat graph) is always on - this value is no longer + configurable via env var. Use NEO4J_GRAPH=1 to enable Neo4j backend instead. 
+ """ snapshot = ws_module.get_indexing_config_snapshot() assert "index_graph_edges" in snapshot, "index_graph_edges should be in config snapshot" - assert snapshot["index_graph_edges"] is True, "Default value for index_graph_edges should be True" + assert snapshot["index_graph_edges"] is True, "index_graph_edges should always be True (always on)" - def test_get_indexing_config_snapshot_respects_env_var(self, ws_module, monkeypatch): - """Verify INDEX_GRAPH_EDGES env var is respected in snapshot.""" - # Test with False + def test_get_indexing_config_snapshot_graph_edges_always_true(self, ws_module, monkeypatch): + """Verify index_graph_edges is always True regardless of env var (now unconditional).""" + # Even with env var set to 0, index_graph_edges should be True (always on) monkeypatch.setenv("INDEX_GRAPH_EDGES", "0") snapshot = ws_module.get_indexing_config_snapshot() - assert snapshot["index_graph_edges"] is False, "INDEX_GRAPH_EDGES=0 should set index_graph_edges to False" + assert snapshot["index_graph_edges"] is True, "index_graph_edges should always be True (env var ignored)" - # Test with True + # Same with env var set to 1 monkeypatch.setenv("INDEX_GRAPH_EDGES", "1") snapshot = ws_module.get_indexing_config_snapshot() - assert snapshot["index_graph_edges"] is True, "INDEX_GRAPH_EDGES=1 should set index_graph_edges to True" + assert snapshot["index_graph_edges"] is True, "index_graph_edges should always be True" def test_config_drift_classifies_graph_edges_as_recreate(self, ws_module): - """Verify that changing INDEX_GRAPH_EDGES triggers recreate drift.""" + """Verify that changing index_graph_edges triggers recreate drift. + + Note: index_graph_edges is now always True, but drift rules still exist + for backwards compatibility with existing indexes that may have False. 
+ """ from scripts import indexing_admin # Verify the drift rule exists and is classified as "recreate" @@ -473,23 +478,15 @@ def test_config_drift_classifies_graph_edges_as_recreate(self, ws_module): assert indexing_admin.CONFIG_DRIFT_RULES["index_graph_edges"] == "recreate", \ "index_graph_edges drift should be classified as 'recreate'" - def test_config_drift_graph_edges_true_to_false(self, ws_module): - """Verify drift from True->False is classified as recreate.""" - from scripts import indexing_admin - - old_config = {"index_graph_edges": True} - new_config = {"index_graph_edges": False} - - # The actual drift detection is more complex, but we can verify the rule - rule = indexing_admin.CONFIG_DRIFT_RULES.get("index_graph_edges") - assert rule == "recreate", "Changing index_graph_edges should require recreate" + def test_config_drift_graph_edges_legacy_false_to_true(self, ws_module): + """Verify drift from legacy False->True is classified as recreate. - def test_config_drift_graph_edges_false_to_true(self, ws_module): - """Verify drift from False->True is classified as recreate.""" + This handles migration from old indexes where graph edges were disabled. 
+ """ from scripts import indexing_admin - old_config = {"index_graph_edges": False} - new_config = {"index_graph_edges": True} + old_config = {"index_graph_edges": False} # Legacy: was disabled + new_config = {"index_graph_edges": True} # Now: always on # The actual drift detection is more complex, but we can verify the rule rule = indexing_admin.CONFIG_DRIFT_RULES.get("index_graph_edges") diff --git a/vscode-extension/context-engine-uploader/README.md b/vscode-extension/context-engine-uploader/README.md index c84a79b6..9b4f64dd 100644 --- a/vscode-extension/context-engine-uploader/README.md +++ b/vscode-extension/context-engine-uploader/README.md @@ -61,7 +61,10 @@ MCP bridge (ctx-mcp-bridge) & MCP config lifecycle - **Centralized logging & health:** when the bridge process runs once per workspace you get a single stream of logs (`Context Engine Upload` output) and a single port to probe for health checks instead of multiple MCP child processes per IDE. - When you run **`Write MCP Config`**, the extension: - Writes `.mcp.json` in the workspace for Claude Code. - - Optionally writes Windsurf’s `mcp_config.json` (when `mcpWindsurfEnabled=true`). + - Optionally writes Windsurf's `mcp_config.json` (when `mcpWindsurfEnabled=true`). + - Optionally writes Augment's `settings.json` (when `mcpAugmentEnabled=true`). + - Optionally writes Antigravity's `mcp_config.json` (when `mcpAntigravityEnabled=true`). + - Optionally writes Cursor's `~/.cursor/mcp.json` (when `mcpCursorEnabled=true`). - Optionally scaffolds `ctx_config.json` + `.env` (when `scaffoldCtxConfig=true`). - The effective wiring mode is determined by the two MCP settings: - `mcpServerMode = bridge`, `mcpTransportMode = sse-remote` → **bridge-stdio**. @@ -78,6 +81,18 @@ MCP bridge (ctx-mcp-bridge) & MCP config lifecycle - In **stdio or direct modes**, the HTTP bridge is **not** auto-started; only the explicit `Start MCP HTTP Bridge` command will launch it. 
- Bridge settings are **workspace-scoped**, so different workspaces can choose different modes and ports (e.g., one workspace using stdio bridge, another using HTTP bridge on a different port). +Cursor Integration +------------------ + +Enable `mcpCursorEnabled` in settings to write MCP config to `~/.cursor/mcp.json`. + +**Caveats:** +- Cursor uses a **global** MCP config at `~/.cursor/mcp.json` (not per-project like Claude's `.mcp.json`). +- After updating the config, you must **restart Cursor** for changes to take effect. +- Cursor's MCP support requires the `http` transport mode. Set `mcpTransportMode` to `http`. +- If using bridge mode, ensure the HTTP bridge is running (`autoStartMcpBridge=true`). +- Custom config path: set `cursorMcpPath` to override the default `~/.cursor/mcp.json` location. + Optional auth with the MCP bridge (PoC) -------------------------------------- diff --git a/vscode-extension/context-engine-uploader/assets/logo.jpeg b/vscode-extension/context-engine-uploader/assets/logo.jpeg new file mode 100644 index 00000000..28ecace8 Binary files /dev/null and b/vscode-extension/context-engine-uploader/assets/logo.jpeg differ diff --git a/vscode-extension/context-engine-uploader/commands.js b/vscode-extension/context-engine-uploader/commands.js index 29f97859..197a0658 100644 --- a/vscode-extension/context-engine-uploader/commands.js +++ b/vscode-extension/context-engine-uploader/commands.js @@ -220,6 +220,14 @@ function registerExtensionCommands(deps) { } })); + disposables.push(vscode.commands.registerCommand('contextEngineUploader.writeMcpConfigCursor', () => { + try { + requireDep(writeMcpConfig, 'writeMcpConfig')({ targets: ['cursor'] }).catch(error => handleCatch(error, 'Failed to write Cursor MCP config')); + } catch (error) { + handleCatch(error, 'Failed to write Cursor MCP config'); + } + })); + // Onboarding/Stack commands disposables.push(vscode.commands.registerCommand('contextEngineUploader.cloneAndStartStack', async () => { try { 
diff --git a/vscode-extension/context-engine-uploader/dashboard.js b/vscode-extension/context-engine-uploader/dashboard.js index 69c6bc2c..3de35cb7 100644 --- a/vscode-extension/context-engine-uploader/dashboard.js +++ b/vscode-extension/context-engine-uploader/dashboard.js @@ -95,13 +95,14 @@ class DashboardViewProvider { _getHtmlContent(webview) { const state = this._getState(); const nonce = getNonce(); + const logoUri = webview.asWebviewUri(vscode.Uri.joinPath(this._extensionUri, 'assets', 'logo.jpeg')); return ` - + Context Engine Dashboard @@ -194,7 +196,7 @@ class SettingsWebviewProvider {