diff --git a/graphify/extract.py b/graphify/extract.py index c5ceddc0..70786fba 100644 --- a/graphify/extract.py +++ b/graphify/extract.py @@ -1,9 +1,13 @@ -"""Deterministic structural extraction from Python code using tree-sitter. Outputs nodes+edges dicts.""" +"""Deterministic structural extraction from source code using tree-sitter. Outputs nodes+edges dicts.""" from __future__ import annotations +import importlib import json +import os import re import sys +from dataclasses import dataclass, field from pathlib import Path +from typing import Callable, Any from .cache import load_cached, save_cached @@ -14,16 +18,673 @@ def _make_id(*parts: str) -> str: return cleaned.strip("_").lower() -def extract_python(path: Path) -> dict: - """Extract classes, functions, and imports from a .py file via tree-sitter AST.""" +# ── LanguageConfig dataclass ───────────────────────────────────────────────── + +@dataclass +class LanguageConfig: + ts_module: str # e.g. "tree_sitter_python" + ts_language_fn: str = "language" # attr to call: e.g. tslang.language() + + class_types: frozenset = frozenset() + function_types: frozenset = frozenset() + import_types: frozenset = frozenset() + reexport_types: frozenset = frozenset() # e.g. {"export_statement"} for JS/TS barrel re-exports + call_types: frozenset = frozenset() + static_prop_types: frozenset = frozenset() + helper_fn_names: frozenset = frozenset() + container_bind_methods: frozenset = frozenset() + event_listener_properties: frozenset = frozenset() + + # Name extraction + name_field: str = "name" + name_fallback_child_types: tuple = () + + # Body detection + body_field: str = "body" + body_fallback_child_types: tuple = () # e.g. ("declaration_list", "compound_statement") + + # Call name extraction + call_function_field: str = "function" # field on call node for callee + call_accessor_node_types: frozenset = frozenset() # member/attribute nodes + call_accessor_field: str = "attribute" # field on accessor for method name + + # Stop recursion at these types in walk_calls + function_boundary_types: frozenset = frozenset() + + # Import handler: called for import nodes instead of generic handling + import_handler: Callable | None = None + + # Optional custom name resolver for functions (C, C++ declarator unwrapping) + resolve_function_name_fn: Callable | None = None + + # Extra label formatting for functions: if True, functions get "name()" label + function_label_parens: bool = True + + # Extra walk hook called after generic dispatch (for JS arrow functions, C# namespaces, etc.) + extra_walk_fn: Callable | None = None + + +# ── Generic helpers ─────────────────────────────────────────────────────────── + +def _read_text(node, source: bytes) -> str: + return source[node.start_byte:node.end_byte].decode("utf-8", errors="replace") + + +def _resolve_name(node, source: bytes, config: LanguageConfig) -> str | None: + """Get the name from a node using config.name_field, falling back to child types.""" + if config.resolve_function_name_fn is not None: + # For C/C++ where the name is inside a declarator + return None # caller handles this separately + n = node.child_by_field_name(config.name_field) + if n: + return _read_text(n, source) + for child in node.children: + if child.type in config.name_fallback_child_types: + return _read_text(child, source) + return None + + +def _find_body(node, config: LanguageConfig): + """Find the body node using config.body_field, falling back to child types.""" + b = node.child_by_field_name(config.body_field) + if b: + return b + for child in node.children: + if child.type in config.body_fallback_child_types: + return child + return None + + +# ── Import handlers ─────────────────────────────────────────────────────────── + +def _import_python(node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str) -> None: + t = node.type + if t == "import_statement": + for child in node.children: + if child.type in ("dotted_name", "aliased_import"): + raw = _read_text(child, source) + module_name = raw.split(" as ")[0].strip().lstrip(".") + tgt_nid = _make_id(module_name) + edges.append({ + "source": file_nid, + "target": tgt_nid, + "relation": "imports", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + }) + elif t == "import_from_statement": + module_node = node.child_by_field_name("module_name") + if module_node: + raw = _read_text(module_node, source) + if raw.startswith("."): + # Relative import - resolve to full path so IDs match file node IDs + dots = len(raw) - len(raw.lstrip(".")) + module_name = raw.lstrip(".") + base = Path(str_path).parent + for _ in range(dots - 1): + base = base.parent + rel = (module_name.replace(".", "/") + ".py") if module_name else "__init__.py" + tgt_nid = _make_id(str(base / rel)) + else: + tgt_nid = _make_id(raw) + edges.append({ + "source": file_nid, + "target": tgt_nid, + "relation": "imports_from", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + }) + + +def _import_js(node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str) -> None: + for child in node.children: + if child.type == "string": + raw = _read_text(child, source).strip("'\"` ") + if not raw: + break + if raw.startswith("."): + # Relative import - resolve to full path so IDs match file node IDs + resolved = Path(str_path).parent / raw + # TypeScript ESM: imports written as .js but actual file is .ts/.tsx + if resolved.suffix == ".js": + resolved = resolved.with_suffix(".ts") + elif resolved.suffix == ".jsx": + resolved = resolved.with_suffix(".tsx") + # Extensionless import: ./Foo → ./Foo.ts (or .tsx/.js/.jsx) + # Barrel import: ./components → ./components/index.ts (or .tsx/.js/.jsx) + if not resolved.suffix: + for ext in (".ts", ".tsx", ".js", ".jsx"): + candidate = resolved.with_suffix(ext) + if candidate.exists(): + resolved = candidate + break + else: + for ext in (".ts", ".tsx", ".js", ".jsx"): + candidate = resolved / ("index" + ext) + if candidate.exists(): + resolved = candidate + break + tgt_nid = _make_id(str(resolved)) + else: + # Bare/scoped import (node_modules) - use last segment; dropped as external + module_name = raw.split("/")[-1] + if not module_name: + break + tgt_nid = _make_id(module_name) + edges.append({ + "source": file_nid, + "target": tgt_nid, + "relation": "imports_from", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + }) + break + + +def _import_java(node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str) -> None: + def _walk_scoped(n) -> str: + parts: list[str] = [] + cur = n + while cur: + if cur.type == "scoped_identifier": + name_node = cur.child_by_field_name("name") + if name_node: + parts.append(_read_text(name_node, source)) + cur = cur.child_by_field_name("scope") + elif cur.type == "identifier": + parts.append(_read_text(cur, source)) + break + else: + break + parts.reverse() + return ".".join(parts) + + for child in node.children: + if child.type in ("scoped_identifier", "identifier"): + path_str = _walk_scoped(child) + module_name = path_str.split(".")[-1].strip("*").strip(".") or ( + path_str.split(".")[-2] if len(path_str.split(".")) > 1 else path_str + ) + if module_name: + tgt_nid = _make_id(module_name) + edges.append({ + "source": file_nid, + "target": tgt_nid, + "relation": "imports", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + }) + break + + +def _import_c(node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str) -> None: + for child in node.children: + if child.type in ("string_literal", "system_lib_string", "string"): + raw = _read_text(child, source).strip('"<> ') + module_name = raw.split("/")[-1].split(".")[0] + if module_name: + tgt_nid = _make_id(module_name) + edges.append({ + "source": file_nid, + "target": tgt_nid, + "relation": "imports", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + }) + break + + +def _import_csharp(node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str) -> None: + for child in node.children: + if child.type in ("qualified_name", "identifier", "name_equals"): + raw = _read_text(child, source) + module_name = raw.split(".")[-1].strip() + if module_name: + tgt_nid = _make_id(module_name) + edges.append({ + "source": file_nid, + "target": tgt_nid, + "relation": "imports", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + }) + break + + +def _import_kotlin(node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str) -> None: + path_node = node.child_by_field_name("path") + if path_node: + raw = _read_text(path_node, source) + module_name = raw.split(".")[-1].strip() + if module_name: + tgt_nid = _make_id(module_name) + edges.append({ + "source": file_nid, + "target": tgt_nid, + "relation": "imports", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + }) + return + # Fallback: find identifier child + for child in node.children: + if child.type == "identifier": + raw = _read_text(child, source) + tgt_nid = _make_id(raw) + edges.append({ + "source": file_nid, + "target": tgt_nid, + "relation": "imports", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + }) + break + + +def _import_scala(node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str) -> None: + for child in node.children: + if child.type in ("stable_id", "identifier"): + raw = _read_text(child, source) + module_name = raw.split(".")[-1].strip("{} ") + if module_name and module_name != "_": + tgt_nid = _make_id(module_name) + edges.append({ + "source": file_nid, + "target": tgt_nid, + "relation": "imports", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + }) + break + + +def _import_php(node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str) -> None: + for child in node.children: + if child.type in ("qualified_name", "name", "identifier"): + raw = _read_text(child, source) + module_name = raw.split("\\")[-1].strip() + if module_name: + tgt_nid = _make_id(module_name) + edges.append({ + "source": file_nid, + "target": tgt_nid, + "relation": "imports", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + }) + break + + +# ── C/C++ function name helpers ─────────────────────────────────────────────── + +def _get_c_func_name(node, source: bytes) -> str | None: + """Recursively unwrap declarator to find the innermost identifier (C).""" + if node.type == "identifier": + return _read_text(node, source) + decl = node.child_by_field_name("declarator") + if decl: + return _get_c_func_name(decl, source) + for child in node.children: + if child.type == "identifier": + return _read_text(child, source) + return None + + +def _get_cpp_func_name(node, source: bytes) -> str | None: + """Recursively unwrap declarator to find the innermost identifier (C++).""" + if node.type == "identifier": + return _read_text(node, source) + if node.type == "qualified_identifier": + name_node = node.child_by_field_name("name") + if name_node: + return _read_text(name_node, source) + decl = node.child_by_field_name("declarator") + if decl: + return _get_cpp_func_name(decl, source) + for child in node.children: + if child.type == "identifier": + return _read_text(child, source) + return None + + +# ── JS/TS extra walk for arrow functions ────────────────────────────────────── + +def _js_extra_walk(node, source: bytes, file_nid: str, stem: str, str_path: str, + nodes: list, edges: list, seen_ids: set, function_bodies: list, + parent_class_nid: str | None, add_node_fn, add_edge_fn) -> bool: + """Handle lexical_declaration (arrow functions) for JS/TS. Returns True if handled.""" + if node.type == "lexical_declaration": + for child in node.children: + if child.type == "variable_declarator": + value = child.child_by_field_name("value") + if value and value.type == "arrow_function": + name_node = child.child_by_field_name("name") + if name_node: + func_name = _read_text(name_node, source) + line = child.start_point[0] + 1 + func_nid = _make_id(stem, func_name) + add_node_fn(func_nid, f"{func_name}()", line) + add_edge_fn(file_nid, func_nid, "contains", line) + body = value.child_by_field_name("body") + if body: + function_bodies.append((func_nid, body)) + return True + return False + + +# ── C# extra walk for namespace declarations ────────────────────────────────── + +def _csharp_extra_walk(node, source: bytes, file_nid: str, stem: str, str_path: str, + nodes: list, edges: list, seen_ids: set, function_bodies: list, + parent_class_nid: str | None, add_node_fn, add_edge_fn, + walk_fn) -> bool: + """Handle namespace_declaration for C#. Returns True if handled.""" + if node.type == "namespace_declaration": + name_node = node.child_by_field_name("name") + if name_node: + ns_name = _read_text(name_node, source) + ns_nid = _make_id(stem, ns_name) + line = node.start_point[0] + 1 + add_node_fn(ns_nid, ns_name, line) + add_edge_fn(file_nid, ns_nid, "contains", line) + body = node.child_by_field_name("body") + if body: + for child in body.children: + walk_fn(child, parent_class_nid) + return True + return False + + +# ── Swift extra walk for enum cases ────────────────────────────────────────── + +def _swift_extra_walk(node, source: bytes, file_nid: str, stem: str, str_path: str, + nodes: list, edges: list, seen_ids: set, function_bodies: list, + parent_class_nid: str | None, add_node_fn, add_edge_fn) -> bool: + """Handle enum_entry for Swift. Returns True if handled.""" + if node.type == "enum_entry" and parent_class_nid: + for child in node.children: + if child.type == "simple_identifier": + case_name = _read_text(child, source) + case_nid = _make_id(parent_class_nid, case_name) + line = node.start_point[0] + 1 + add_node_fn(case_nid, case_name, line) + add_edge_fn(parent_class_nid, case_nid, "case_of", line) + return True + return False + + +# ── Language configs ────────────────────────────────────────────────────────── + +_PYTHON_CONFIG = LanguageConfig( + ts_module="tree_sitter_python", + class_types=frozenset({"class_definition"}), + function_types=frozenset({"function_definition"}), + import_types=frozenset({"import_statement", "import_from_statement"}), + call_types=frozenset({"call"}), + call_function_field="function", + call_accessor_node_types=frozenset({"attribute"}), + call_accessor_field="attribute", + function_boundary_types=frozenset({"function_definition"}), + import_handler=_import_python, +) + +_JS_CONFIG = LanguageConfig( + ts_module="tree_sitter_javascript", + class_types=frozenset({"class_declaration"}), + function_types=frozenset({"function_declaration", "method_definition"}), + import_types=frozenset({"import_statement"}), + reexport_types=frozenset({"export_statement"}), + call_types=frozenset({"call_expression"}), + call_function_field="function", + call_accessor_node_types=frozenset({"member_expression"}), + call_accessor_field="property", + function_boundary_types=frozenset({"function_declaration", "arrow_function", "method_definition"}), + import_handler=_import_js, +) + +_TS_CONFIG = LanguageConfig( + ts_module="tree_sitter_typescript", + ts_language_fn="language_typescript", + class_types=frozenset({"class_declaration"}), + function_types=frozenset({"function_declaration", "method_definition"}), + import_types=frozenset({"import_statement"}), + reexport_types=frozenset({"export_statement"}), + call_types=frozenset({"call_expression"}), + call_function_field="function", + call_accessor_node_types=frozenset({"member_expression"}), + call_accessor_field="property", + function_boundary_types=frozenset({"function_declaration", "arrow_function", "method_definition"}), + import_handler=_import_js, +) + +_JAVA_CONFIG = LanguageConfig( + ts_module="tree_sitter_java", + class_types=frozenset({"class_declaration", "interface_declaration"}), + function_types=frozenset({"method_declaration", "constructor_declaration"}), + import_types=frozenset({"import_declaration"}), + call_types=frozenset({"method_invocation"}), + call_function_field="name", + call_accessor_node_types=frozenset(), + function_boundary_types=frozenset({"method_declaration", "constructor_declaration"}), + import_handler=_import_java, +) + +_C_CONFIG = LanguageConfig( + ts_module="tree_sitter_c", + class_types=frozenset(), + function_types=frozenset({"function_definition"}), + import_types=frozenset({"preproc_include"}), + call_types=frozenset({"call_expression"}), + call_function_field="function", + call_accessor_node_types=frozenset({"field_expression"}), + call_accessor_field="field", + function_boundary_types=frozenset({"function_definition"}), + import_handler=_import_c, + resolve_function_name_fn=_get_c_func_name, +) + +_CPP_CONFIG = LanguageConfig( + ts_module="tree_sitter_cpp", + class_types=frozenset({"class_specifier"}), + function_types=frozenset({"function_definition"}), + import_types=frozenset({"preproc_include"}), + call_types=frozenset({"call_expression"}), + call_function_field="function", + call_accessor_node_types=frozenset({"field_expression", "qualified_identifier"}), + call_accessor_field="field", + function_boundary_types=frozenset({"function_definition"}), + import_handler=_import_c, + resolve_function_name_fn=_get_cpp_func_name, +) + +_RUBY_CONFIG = LanguageConfig( + ts_module="tree_sitter_ruby", + class_types=frozenset({"class"}), + function_types=frozenset({"method", "singleton_method"}), + import_types=frozenset(), + call_types=frozenset({"call"}), + call_function_field="method", + call_accessor_node_types=frozenset(), + name_fallback_child_types=("constant", "scope_resolution", "identifier"), + body_fallback_child_types=("body_statement",), + function_boundary_types=frozenset({"method", "singleton_method"}), +) + +_CSHARP_CONFIG = LanguageConfig( + ts_module="tree_sitter_c_sharp", + class_types=frozenset({"class_declaration", "interface_declaration"}), + function_types=frozenset({"method_declaration"}), + import_types=frozenset({"using_directive"}), + call_types=frozenset({"invocation_expression"}), + call_function_field="function", + call_accessor_node_types=frozenset({"member_access_expression"}), + call_accessor_field="name", + body_fallback_child_types=("declaration_list",), + function_boundary_types=frozenset({"method_declaration"}), + import_handler=_import_csharp, +) + +_KOTLIN_CONFIG = LanguageConfig( + ts_module="tree_sitter_kotlin", + class_types=frozenset({"class_declaration", "object_declaration"}), + function_types=frozenset({"function_declaration"}), + import_types=frozenset({"import_header"}), + call_types=frozenset({"call_expression"}), + call_function_field="", + call_accessor_node_types=frozenset({"navigation_expression"}), + call_accessor_field="", + name_fallback_child_types=("simple_identifier",), + body_fallback_child_types=("function_body", "class_body"), + function_boundary_types=frozenset({"function_declaration"}), + import_handler=_import_kotlin, +) + +_SCALA_CONFIG = LanguageConfig( + ts_module="tree_sitter_scala", + class_types=frozenset({"class_definition", "object_definition"}), + function_types=frozenset({"function_definition"}), + import_types=frozenset({"import_declaration"}), + call_types=frozenset({"call_expression"}), + call_function_field="", + call_accessor_node_types=frozenset({"field_expression"}), + call_accessor_field="field", + name_fallback_child_types=("identifier",), + body_fallback_child_types=("template_body",), + function_boundary_types=frozenset({"function_definition"}), + import_handler=_import_scala, +) + +_PHP_CONFIG = LanguageConfig( + ts_module="tree_sitter_php", + ts_language_fn="language_php", + class_types=frozenset({"class_declaration"}), + function_types=frozenset({"function_definition", "method_declaration"}), + import_types=frozenset({"namespace_use_clause"}), + call_types=frozenset({"function_call_expression", "member_call_expression", "scoped_call_expression", "class_constant_access_expression"}), + static_prop_types=frozenset({"scoped_property_access_expression"}), + helper_fn_names=frozenset({"config"}), + container_bind_methods=frozenset({"bind", "singleton", "scoped", "instance"}), + event_listener_properties=frozenset({"listen", "subscribe"}), + call_function_field="function", + call_accessor_node_types=frozenset({"member_call_expression"}), + call_accessor_field="name", + name_fallback_child_types=("name",), + body_fallback_child_types=("declaration_list", "compound_statement"), + function_boundary_types=frozenset({"function_definition", "method_declaration"}), + import_handler=_import_php, +) + + +def _import_lua(node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str) -> None: + """Extract require('module') from Lua variable_declaration nodes.""" + text = _read_text(node, source) + import re + m = re.search(r"""require\s*[\('"]\s*['"]?([^'")\s]+)""", text) + if m: + module_name = m.group(1).split(".")[-1] + if module_name: + edges.append({ + "source": file_nid, + "target": module_name, + "relation": "imports", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": str_path, + "source_location": str(node.start_point[0] + 1), + "weight": 1.0, + }) + + +_LUA_CONFIG = LanguageConfig( + ts_module="tree_sitter_lua", + ts_language_fn="language", + class_types=frozenset(), + function_types=frozenset({"function_declaration"}), + import_types=frozenset({"variable_declaration"}), + call_types=frozenset({"function_call"}), + call_function_field="name", + call_accessor_node_types=frozenset({"method_index_expression"}), + call_accessor_field="name", + name_fallback_child_types=("identifier", "method_index_expression"), + body_fallback_child_types=("block",), + function_boundary_types=frozenset({"function_declaration"}), + import_handler=_import_lua, +) + + +def _import_swift(node, source: bytes, file_nid: str, stem: str, edges: list, str_path: str) -> None: + for child in node.children: + if child.type == "identifier": + raw = _read_text(child, source) + tgt_nid = _make_id(raw) + edges.append({ + "source": file_nid, + "target": tgt_nid, + "relation": "imports", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + "weight": 1.0, + }) + break + + +_SWIFT_CONFIG = LanguageConfig( + ts_module="tree_sitter_swift", + class_types=frozenset({"class_declaration", "protocol_declaration"}), + function_types=frozenset({"function_declaration", "init_declaration", "deinit_declaration", "subscript_declaration"}), + import_types=frozenset({"import_declaration"}), + call_types=frozenset({"call_expression"}), + call_function_field="", + call_accessor_node_types=frozenset({"navigation_expression"}), + call_accessor_field="", + name_fallback_child_types=("simple_identifier", "type_identifier", "user_type"), + body_fallback_child_types=("class_body", "protocol_body", "function_body", "enum_class_body"), + function_boundary_types=frozenset({"function_declaration", "init_declaration", "deinit_declaration", "subscript_declaration"}), + import_handler=_import_swift, +) + + +# ── Generic extractor ───────────────────────────────────────────────────────── + +def _extract_generic(path: Path, config: LanguageConfig) -> dict: + """Generic AST extractor driven by LanguageConfig.""" try: - import tree_sitter_python as tspython + mod = importlib.import_module(config.ts_module) from tree_sitter import Language, Parser + lang_fn = getattr(mod, config.ts_language_fn, None) + if lang_fn is None: + # Fallback for PHP: try "language_php" then "language" + lang_fn = getattr(mod, "language", None) + if lang_fn is None: + return {"nodes": [], "edges": [], "error": f"No language function in {config.ts_module}"} + language = Language(lang_fn()) except ImportError: - return {"nodes": [], "edges": [], "error": "tree-sitter-python not installed"} + return {"nodes": [], "edges": [], "error": f"{config.ts_module} not installed"} + except Exception as e: + return {"nodes": [], "edges": [], "error": str(e)} try: - language = Language(tspython.language()) parser = Parser(language) source = path.read_bytes() tree = parser.parse(source) @@ -36,6 +697,8 @@ def extract_python(path: Path) -> dict: nodes: list[dict] = [] edges: list[dict] = [] seen_ids: set[str] = set() + function_bodies: list[tuple[str, object]] = [] + pending_listen_edges: list[tuple[str, str, int]] = [] def add_node(nid: str, label: str, line: int) -> None: if nid not in seen_ids: @@ -48,86 +711,205 @@ def add_node(nid: str, label: str, line: int) -> None: "source_location": f"L{line}", }) - def add_edge(src: str, tgt: str, relation: str, line: int) -> None: - # Only add edge if both endpoints exist or src is the file node + def add_edge(src: str, tgt: str, relation: str, line: int, + confidence: str = "EXTRACTED", weight: float = 1.0) -> None: edges.append({ "source": src, "target": tgt, "relation": relation, - "confidence": "EXTRACTED", + "confidence": confidence, "source_file": str_path, "source_location": f"L{line}", - "weight": 1.0, + "weight": weight, }) - # File-level node - stable ID based on stem only - file_nid = _make_id(stem) + file_nid = _make_id(str(path)) add_node(file_nid, path.name, 1) def walk(node, parent_class_nid: str | None = None) -> None: t = node.type - if t == "import_statement": - for child in node.children: - if child.type in ("dotted_name", "aliased_import"): - raw = source[child.start_byte:child.end_byte].decode("utf-8", errors="replace") - module_name = raw.split(" as ")[0].strip().lstrip(".") - tgt_nid = _make_id(module_name) - add_edge(file_nid, tgt_nid, "imports", node.start_point[0] + 1) + # Import types + if t in config.import_types: + if config.import_handler: + config.import_handler(node, source, file_nid, stem, edges, str_path) return - if t == "import_from_statement": - module_node = node.child_by_field_name("module_name") - if module_node: - raw = source[module_node.start_byte:module_node.end_byte].decode("utf-8", errors="replace").lstrip(".") - tgt_nid = _make_id(raw) - add_edge(file_nid, tgt_nid, "imports_from", node.start_point[0] + 1) - return + # Re-export types (barrel files): export * from '...' / export { X } from '...' + if config.reexport_types and t in config.reexport_types: + has_from = any(c.type == "from" for c in node.children) + if has_from and config.import_handler: + config.import_handler(node, source, file_nid, stem, edges, str_path) + return + # Regular exports (no 'from') fall through to walk children normally - if t == "class_definition": - name_node = node.child_by_field_name("name") + # Class types + if t in config.class_types: + # Resolve class name + name_node = node.child_by_field_name(config.name_field) + if name_node is None: + for child in node.children: + if child.type in config.name_fallback_child_types: + name_node = child + break if not name_node: return - class_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") + class_name = _read_text(name_node, source) class_nid = _make_id(stem, class_name) line = node.start_point[0] + 1 add_node(class_nid, class_name, line) add_edge(file_nid, class_nid, "contains", line) - # Inheritance - create stub node for external bases so the edge is never dropped - args = node.child_by_field_name("superclasses") - if args: - for arg in args.children: - if arg.type == "identifier": - base = source[arg.start_byte:arg.end_byte].decode("utf-8", errors="replace") - # Try same-file base first; fall back to a bare stub - base_nid = _make_id(stem, base) - if base_nid not in seen_ids: - # External or forward-declared base - add a stub so edge survives - base_nid = _make_id(base) + # Python-specific: inheritance + if config.ts_module == "tree_sitter_python": + args = node.child_by_field_name("superclasses") + if args: + for arg in args.children: + if arg.type == "identifier": + base = _read_text(arg, source) + base_nid = _make_id(stem, base) if base_nid not in seen_ids: - nodes.append({ - "id": base_nid, - "label": base, - "file_type": "code", - "source_file": "", - "source_location": "", - }) - seen_ids.add(base_nid) - add_edge(class_nid, base_nid, "inherits", line) - - # Walk class body for methods - body = node.child_by_field_name("body") + base_nid = _make_id(base) + if base_nid not in seen_ids: + nodes.append({ + "id": base_nid, + "label": base, + "file_type": "code", + "source_file": "", + "source_location": "", + }) + seen_ids.add(base_nid) + add_edge(class_nid, base_nid, "inherits", line) + + # Swift-specific: conformance / inheritance + if config.ts_module == "tree_sitter_swift": + for child in node.children: + if child.type == "inheritance_specifier": + for sub in child.children: + if sub.type in ("user_type", "type_identifier"): + base = _read_text(sub, source) + base_nid = _make_id(stem, base) + if base_nid not in seen_ids: + base_nid = _make_id(base) + if base_nid not in seen_ids: + nodes.append({ + "id": base_nid, + "label": base, + "file_type": "code", + "source_file": "", + "source_location": "", + }) + seen_ids.add(base_nid) + add_edge(class_nid, base_nid, "inherits", line) + + # C#-specific: inheritance / interface implementation via base_list + if config.ts_module == "tree_sitter_c_sharp": + for child in node.children: + if child.type == "base_list": + for sub in child.children: + if sub.type in ("identifier", "generic_name"): + if sub.type == "generic_name": + name_child = sub.child_by_field_name("name") + base = _read_text(name_child, source) if name_child else _read_text(sub.children[0], source) + else: + base = _read_text(sub, source) + base_nid = _make_id(stem, base) + if base_nid not in seen_ids: + base_nid = _make_id(base) + if base_nid not in seen_ids: + nodes.append({ + "id": base_nid, + "label": base, + "file_type": "code", + "source_file": "", + "source_location": "", + }) + seen_ids.add(base_nid) + add_edge(class_nid, base_nid, "inherits", line) + + # Find body and recurse + body = _find_body(node, config) if body: for child in body.children: walk(child, parent_class_nid=class_nid) return - if t == "function_definition": - name_node = node.child_by_field_name("name") - if not name_node: + # Event listener property arrays: $listen = [Event::class => [Listener::class]] + if (t == "property_declaration" + and parent_class_nid + and config.event_listener_properties): + for element in node.children: + if element.type != "property_element": + continue + prop_name: str | None = None + array_node = None + for c in element.children: + if c.type == "variable_name": + for sc in c.children: + if sc.type == "name": + prop_name = _read_text(sc, source) + break + elif c.type == "array_creation_expression": + array_node = c + if (prop_name is None + or prop_name not in config.event_listener_properties + or array_node is None): + continue + for entry in array_node.children: + if entry.type != "array_element_initializer": + continue + event_cls: str | None = None + listener_arr = None + for sub in entry.children: + if sub.type == "class_constant_access_expression" and event_cls is None: + for sc in sub.children: + if sc.is_named and sc.type in ("name", "qualified_name"): + event_cls = _read_text(sc, source) + break + elif sub.type == "array_creation_expression": + listener_arr = sub + if not event_cls or listener_arr is None: + continue + for listener_entry in listener_arr.children: + if listener_entry.type != "array_element_initializer": + continue + for item in listener_entry.children: + if item.type != "class_constant_access_expression": + continue + for sc in item.children: + if sc.is_named and sc.type in ("name", "qualified_name"): + listener_cls = _read_text(sc, source) + line_no = item.start_point[0] + 1 + pending_listen_edges.append((event_cls, listener_cls, line_no)) + break + break + return + + # Function types + if t in config.function_types: + # Swift deinit/subscript have no name field — resolve before generic fallback + if t == "deinit_declaration": + func_name: str | None = "deinit" + elif t == "subscript_declaration": + func_name = "subscript" + elif config.resolve_function_name_fn is not None: + # C/C++ style: use declarator + declarator = node.child_by_field_name("declarator") + func_name = None + if declarator: + func_name = config.resolve_function_name_fn(declarator, source) + else: + name_node = node.child_by_field_name(config.name_field) + if name_node is None: + for child in node.children: + if child.type in config.name_fallback_child_types: + name_node = child + break + func_name = _read_text(name_node, source) if name_node else None + + if not func_name: return - func_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") + line = node.start_point[0] + 1 if parent_class_nid: func_nid = _make_id(parent_class_nid, func_name) @@ -137,22 +919,38 @@ def walk(node, parent_class_nid: str | None = None) -> None: func_nid = _make_id(stem, func_name) add_node(func_nid, f"{func_name}()", line) add_edge(file_nid, func_nid, "contains", line) - # Collect body for the call-graph pass below - body = node.child_by_field_name("body") + + body = _find_body(node, config) if body: function_bodies.append((func_nid, body)) return + # JS/TS arrow functions and C# namespaces — language-specific extra handling + if config.ts_module in ("tree_sitter_javascript", "tree_sitter_typescript"): + if _js_extra_walk(node, source, file_nid, stem, str_path, + nodes, edges, seen_ids, function_bodies, + parent_class_nid, add_node, add_edge): + return + + if config.ts_module == "tree_sitter_c_sharp": + if _csharp_extra_walk(node, source, file_nid, stem, str_path, + nodes, edges, seen_ids, function_bodies, + parent_class_nid, add_node, add_edge, walk): + return + + if config.ts_module == "tree_sitter_swift": + if _swift_extra_walk(node, source, file_nid, stem, str_path, + nodes, edges, seen_ids, function_bodies, + parent_class_nid, add_node, add_edge): + return + + # Default: recurse for child in node.children: walk(child, parent_class_nid=None) - function_bodies: list[tuple[str, object]] = [] walk(root) # ── Call-graph pass ─────────────────────────────────────────────────────── - # Build label→nid lookup from all nodes collected above. - # Normalise: strip "()" suffix and leading "." so "cohesion_score()" and - # ".cohesion_score()" both map to the same entry. label_to_nid: dict[str, str] = {} for n in nodes: raw = n["label"] @@ -160,21 +958,122 @@ def walk(node, parent_class_nid: str | None = None) -> None: label_to_nid[normalised.lower()] = n["id"] seen_call_pairs: set[tuple[str, str]] = set() + seen_static_ref_pairs: set[tuple[str, str, str]] = set() + seen_helper_ref_pairs: set[tuple[str, str, str]] = set() + seen_bind_pairs: set[tuple[str, str, str]] = set() + raw_calls: list[dict] = [] # unresolved calls for cross-file resolution in extract() + + def _php_class_const_scope(n) -> str | None: + scope = n.child_by_field_name("scope") + if scope is None: + for c in n.children: + if c.is_named and c.type in ("name", "qualified_name", "identifier"): + scope = c + break + if scope is None: + return None + return _read_text(scope, source) def walk_calls(node, caller_nid: str) -> None: - # Don't recurse into nested function definitions - they have their own context. - if node.type == "function_definition": + if node.type in config.function_boundary_types: return - if node.type == "call": - func_node = node.child_by_field_name("function") + + if node.type in config.call_types: callee_name: str | None = None - if func_node: - if func_node.type == "identifier": - callee_name = source[func_node.start_byte:func_node.end_byte].decode("utf-8", errors="replace") - elif func_node.type == "attribute": - attr = func_node.child_by_field_name("attribute") - if attr: - callee_name = source[attr.start_byte:attr.end_byte].decode("utf-8", errors="replace") + + # Special handling per language + if config.ts_module == "tree_sitter_swift": + # Swift: first child may be simple_identifier or navigation_expression + first = node.children[0] if node.children else None + if first: + if first.type == "simple_identifier": + callee_name = _read_text(first, source) + elif first.type == "navigation_expression": + for child in first.children: + if child.type == "navigation_suffix": + for sc in child.children: + if sc.type == "simple_identifier": + callee_name = _read_text(sc, source) + elif config.ts_module == "tree_sitter_kotlin": + # Kotlin: first child may be simple_identifier or navigation_expression + first = node.children[0] if node.children else None + if first: + if first.type == "simple_identifier": + callee_name = _read_text(first, source) + elif first.type == "navigation_expression": + for child in reversed(first.children): + if child.type == "simple_identifier": + callee_name = _read_text(child, source) + break + elif config.ts_module == "tree_sitter_scala": + # Scala: first child + first = node.children[0] if node.children else None + if first: + if first.type == "identifier": + callee_name = _read_text(first, source) + elif first.type == "field_expression": + field = first.child_by_field_name("field") + if field: + callee_name = _read_text(field, source) + else: + for child in reversed(first.children): + if child.type == "identifier": + callee_name = _read_text(child, source) + break + elif config.ts_module == "tree_sitter_c_sharp" and node.type == "invocation_expression": + # C#: try name field, then first named child + name_node = node.child_by_field_name("name") + if name_node: + callee_name = _read_text(name_node, source) + else: + for child in node.children: + if child.is_named: + raw = _read_text(child, source) + if "." in raw: + callee_name = raw.split(".")[-1] + else: + callee_name = raw + break + elif config.ts_module == "tree_sitter_php": + # PHP: distinguish call expression subtypes + if node.type == "function_call_expression": + func_node = node.child_by_field_name("function") + if func_node: + callee_name = _read_text(func_node, source) + elif node.type == "scoped_call_expression": + # Static method call: Helper::format() → callee = "Helper" + scope_node = node.child_by_field_name("scope") + if scope_node: + callee_name = _read_text(scope_node, source) + else: + name_node = node.child_by_field_name("name") + if name_node: + callee_name = _read_text(name_node, source) + elif config.ts_module == "tree_sitter_cpp": + # C++: function field, then field_expression/qualified_identifier + func_node = node.child_by_field_name(config.call_function_field) if config.call_function_field else None + if func_node: + if func_node.type == "identifier": + callee_name = _read_text(func_node, source) + elif func_node.type in ("field_expression", "qualified_identifier"): + name = func_node.child_by_field_name("field") or func_node.child_by_field_name("name") + if name: + callee_name = _read_text(name, source) + else: + # Generic: get callee from call_function_field + func_node = node.child_by_field_name(config.call_function_field) if config.call_function_field else None + if func_node: + if func_node.type == "identifier": + callee_name = _read_text(func_node, source) + elif func_node.type in config.call_accessor_node_types: + if config.call_accessor_field: + attr = func_node.child_by_field_name(config.call_accessor_field) + if attr: + callee_name = _read_text(attr, source) + else: + # Try reading the node directly (e.g. Java name field is the callee) + callee_name = _read_text(func_node, source) + if callee_name: tgt_nid = label_to_nid.get(callee_name.lower()) if tgt_nid and tgt_nid != caller_nid: @@ -186,219 +1085,773 @@ def walk_calls(node, caller_nid: str) -> None: "source": caller_nid, "target": tgt_nid, "relation": "calls", - "confidence": "INFERRED", + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{line}", + "weight": 1.0, + }) + elif callee_name and not tgt_nid: + # Callee not in this file — save for cross-file resolution in extract() + raw_calls.append({ + "caller_nid": caller_nid, + "callee": callee_name, + "source_file": str_path, + "source_location": f"L{node.start_point[0] + 1}", + }) + + # Helper function calls: config('foo.bar') → uses_config edge to "foo" + if (callee_name and callee_name in config.helper_fn_names): + args_node = node.child_by_field_name("arguments") + first_key: str | None = None + if args_node: + for arg in args_node.children: + if arg.type != "argument": + continue + for inner in arg.children: + if inner.type == "string": + for sc in inner.children: + if sc.type == "string_content": + first_key = _read_text(sc, source) + break + break + if first_key: + break + if first_key: + segment = first_key.split(".")[0] + tgt_nid = (label_to_nid.get(segment.lower()) + or label_to_nid.get(f"{segment}.php".lower())) + if tgt_nid and tgt_nid != caller_nid: + relation = f"uses_{callee_name}" + pair3 = (caller_nid, tgt_nid, relation) + if pair3 not in seen_helper_ref_pairs: + seen_helper_ref_pairs.add(pair3) + line = node.start_point[0] + 1 + edges.append({ + "source": caller_nid, + "target": tgt_nid, + "relation": relation, + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": str_path, + "source_location": f"L{line}", + "weight": 1.0, + }) + + # Service container bindings: $this->app->bind(Foo::class, Bar::class) + if (node.type == "member_call_expression" + and callee_name + and callee_name in config.container_bind_methods): + args_node = node.child_by_field_name("arguments") + class_args: list[str] = [] + if args_node: + for arg in args_node.children: + if arg.type != "argument": + continue + for inner in arg.children: + if inner.type == "class_constant_access_expression": + cls = _php_class_const_scope(inner) + if cls: + class_args.append(cls) + break + if len(class_args) >= 2: + break + if len(class_args) == 2: + contract_name, impl_name = class_args + contract_nid = label_to_nid.get(contract_name.lower()) + impl_nid = label_to_nid.get(impl_name.lower()) + if contract_nid and impl_nid and contract_nid != impl_nid: + pair3 = (contract_nid, impl_nid, "bound_to") + if pair3 not in seen_bind_pairs: + seen_bind_pairs.add(pair3) + line = node.start_point[0] + 1 + edges.append({ + "source": contract_nid, + "target": impl_nid, + "relation": "bound_to", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": str_path, + "source_location": f"L{line}", + "weight": 1.0, + }) + + # Static property access: Foo::$bar → uses_static_prop edge + if node.type in config.static_prop_types: + scope_node = node.child_by_field_name("scope") + if scope_node is None: + for child in node.children: + if child.is_named and child.type in ("name", "qualified_name", "identifier"): + scope_node = child + break + if scope_node is not None: + class_name = _read_text(scope_node, source) + tgt_nid = label_to_nid.get(class_name.lower()) + if tgt_nid and tgt_nid != caller_nid: + pair3 = (caller_nid, tgt_nid, "uses_static_prop") + if pair3 not in seen_static_ref_pairs: + seen_static_ref_pairs.add(pair3) + line = node.start_point[0] + 1 + edges.append({ + "source": caller_nid, + "target": tgt_nid, + "relation": "uses_static_prop", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": str_path, + "source_location": f"L{line}", + "weight": 1.0, + }) + + # PHP class constant access: Foo::BAR → references_constant edge + if config.ts_module == "tree_sitter_php" and node.type == "class_constant_access_expression": + class_name = _php_class_const_scope(node) + if class_name: + tgt_nid = label_to_nid.get(class_name.lower()) + if tgt_nid and tgt_nid != caller_nid: + pair3 = (caller_nid, tgt_nid, "references_constant") + if pair3 not in seen_static_ref_pairs: + seen_static_ref_pairs.add(pair3) + line = node.start_point[0] + 1 + edges.append({ + "source": caller_nid, + "target": tgt_nid, + "relation": "references_constant", + "confidence": "EXTRACTED", + "confidence_score": 1.0, "source_file": str_path, "source_location": f"L{line}", - "weight": 0.8, + "weight": 1.0, }) + for child in node.children: walk_calls(child, caller_nid) for caller_nid, body_node in function_bodies: walk_calls(body_node, caller_nid) - # ───────────────────────────────────────────────────────────────────────── - # Post-process: remove edges whose source or target was never added as a node - # (dangling import edges pointing to external libraries are fine to keep, - # but edges between internal entities must be valid) + # ── Event listener pass ─────────────────────────────────────────────────── + seen_listen_pairs: set[tuple[str, str]] = set() + for event_name, listener_name, line in pending_listen_edges: + event_nid = label_to_nid.get(event_name.lower()) + listener_nid = label_to_nid.get(listener_name.lower()) + if not event_nid or not listener_nid or event_nid == listener_nid: + continue + pair2 = (event_nid, listener_nid) + if pair2 in seen_listen_pairs: + continue + seen_listen_pairs.add(pair2) + edges.append({ + "source": event_nid, + "target": listener_nid, + "relation": "listened_by", + "confidence": "EXTRACTED", + "confidence_score": 1.0, + "source_file": str_path, + "source_location": f"L{line}", + "weight": 1.0, + }) + + # ── Clean edges ─────────────────────────────────────────────────────────── valid_ids = seen_ids clean_edges = [] for edge in edges: src, tgt = edge["source"], edge["target"] - # Keep if both endpoints are known, OR if it's an import edge (tgt may be external) if src in valid_ids and (tgt in valid_ids or edge["relation"] in ("imports", "imports_from")): clean_edges.append(edge) - return {"nodes": nodes, "edges": clean_edges} + return {"nodes": nodes, "edges": clean_edges, "raw_calls": raw_calls} -def extract_js(path: Path) -> dict: - """Extract classes, functions, arrow functions, and imports from a .js/.ts/.tsx file.""" - try: - if path.suffix in (".ts", ".tsx"): - import tree_sitter_typescript as tslang - from tree_sitter import Language, Parser - language = Language(tslang.language_typescript()) - else: - import tree_sitter_javascript as tslang - from tree_sitter import Language, Parser - language = Language(tslang.language()) - except ImportError: - return {"nodes": [], "edges": [], "error": "tree-sitter-javascript/typescript not installed"} +# ── Python rationale extraction ─────────────────────────────────────────────── + +_RATIONALE_PREFIXES = ("# NOTE:", "# IMPORTANT:", "# HACK:", "# WHY:", "# RATIONALE:", "# TODO:", "# FIXME:") + +def _extract_python_rationale(path: Path, result: dict) -> None: + """Post-pass: extract docstrings and rationale comments from Python source. + Mutates result in-place by appending to result['nodes'] and result['edges']. + """ try: + import tree_sitter_python as tspython + from tree_sitter import Language, Parser + language = Language(tspython.language()) parser = Parser(language) source = path.read_bytes() tree = parser.parse(source) root = tree.root_node - except Exception as e: - return {"nodes": [], "edges": [], "error": str(e)} + except Exception: + return stem = path.stem str_path = str(path) - nodes: list[dict] = [] - edges: list[dict] = [] - seen_ids: set[str] = set() + nodes = result["nodes"] + edges = result["edges"] + seen_ids = {n["id"] for n in nodes} + file_nid = _make_id(str(path)) + + def _get_docstring(body_node) -> tuple[str, int] | None: + if not body_node: + return None + for child in body_node.children: + if child.type == "expression_statement": + for sub in child.children: + if sub.type in ("string", "concatenated_string"): + text = source[sub.start_byte:sub.end_byte].decode("utf-8", errors="replace") + text = text.strip("\"'").strip('"""').strip("'''").strip() + if len(text) > 20: + return text, child.start_point[0] + 1 + break + return None - def add_node(nid: str, label: str, line: int) -> None: - if nid not in seen_ids: - seen_ids.add(nid) + def _add_rationale(text: str, line: int, parent_nid: str) -> None: + label = text[:80].replace("\r\n", " ").replace("\r", " ").replace("\n", " ").strip() + rid = _make_id(stem, "rationale", str(line)) + if rid not in seen_ids: + seen_ids.add(rid) nodes.append({ - "id": nid, + "id": rid, "label": label, - "file_type": "code", + "file_type": "rationale", "source_file": str_path, "source_location": f"L{line}", }) - - def add_edge(src: str, tgt: str, relation: str, line: int, confidence: str = "EXTRACTED", weight: float = 1.0) -> None: edges.append({ - "source": src, - "target": tgt, - "relation": relation, - "confidence": confidence, + "source": rid, + "target": parent_nid, + "relation": "rationale_for", + "confidence": "EXTRACTED", "source_file": str_path, "source_location": f"L{line}", - "weight": weight, + "weight": 1.0, }) - file_nid = _make_id(stem) - add_node(file_nid, path.name, 1) - - function_bodies: list[tuple[str, object]] = [] + # Module-level docstring + ds = _get_docstring(root) + if ds: + _add_rationale(ds[0], ds[1], file_nid) - def walk(node, parent_class_nid: str | None = None) -> None: + # Class and function docstrings + def walk_docstrings(node, parent_nid: str) -> None: t = node.type - - if t == "import_statement": - for child in node.children: - if child.type == "string": - raw = source[child.start_byte:child.end_byte].decode("utf-8", errors="replace").strip("'\"` ") - module_name = raw.lstrip("./").split("/")[-1] - if module_name: - tgt_nid = _make_id(module_name) - add_edge(file_nid, tgt_nid, "imports_from", node.start_point[0] + 1) - return - - if t == "class_declaration": + if t == "class_definition": name_node = node.child_by_field_name("name") - if not name_node: - return - class_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") - class_nid = _make_id(stem, class_name) - line = node.start_point[0] + 1 - add_node(class_nid, class_name, line) - add_edge(file_nid, class_nid, "contains", line) body = node.child_by_field_name("body") - if body: + if name_node and body: + class_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") + nid = _make_id(stem, class_name) + ds = _get_docstring(body) + if ds: + _add_rationale(ds[0], ds[1], nid) for child in body.children: - walk(child, parent_class_nid=class_nid) + walk_docstrings(child, nid) return - - if t == "function_declaration": + if t == "function_definition": name_node = node.child_by_field_name("name") - if not name_node: - return - func_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") - line = node.start_point[0] + 1 - func_nid = _make_id(stem, func_name) - add_node(func_nid, f"{func_name}()", line) - add_edge(file_nid, func_nid, "contains", line) body = node.child_by_field_name("body") - if body: - function_bodies.append((func_nid, body)) + if name_node and body: + func_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") + nid = _make_id(parent_nid, func_name) if parent_nid != file_nid else _make_id(stem, func_name) + ds = _get_docstring(body) + if ds: + _add_rationale(ds[0], ds[1], nid) return + for child in node.children: + walk_docstrings(child, parent_nid) - if t == "method_definition" and parent_class_nid: - name_node = node.child_by_field_name("name") - if not name_node: - return - method_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") - line = node.start_point[0] + 1 - method_nid = _make_id(parent_class_nid, method_name) - add_node(method_nid, f".{method_name}()", line) - add_edge(parent_class_nid, method_nid, "method", line) - body = node.child_by_field_name("body") - if body: - function_bodies.append((method_nid, body)) - return + walk_docstrings(root, file_nid) - if t == "lexical_declaration": - # Arrow functions: const foo = (...) => { ... } - for child in node.children: - if child.type == "variable_declarator": - value = child.child_by_field_name("value") - if value and value.type == "arrow_function": - name_node = child.child_by_field_name("name") - if name_node: - func_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") - line = child.start_point[0] + 1 - func_nid = _make_id(stem, func_name) - add_node(func_nid, f"{func_name}()", line) - add_edge(file_nid, func_nid, "contains", line) - body = value.child_by_field_name("body") - if body: - function_bodies.append((func_nid, body)) - return + # Rationale comments (# NOTE:, # IMPORTANT:, etc.) + source_text = source.decode("utf-8", errors="replace") + for lineno, line_text in enumerate(source_text.splitlines(), start=1): + stripped = line_text.strip() + if any(stripped.startswith(p) for p in _RATIONALE_PREFIXES): + _add_rationale(stripped, lineno, file_nid) + + +# ── Public API ──────────────────────────────────────────────────────────────── + +def extract_python(path: Path) -> dict: + """Extract classes, functions, and imports from a .py file via tree-sitter AST.""" + result = _extract_generic(path, _PYTHON_CONFIG) + if "error" not in result: + _extract_python_rationale(path, result) + return result + + +def extract_js(path: Path) -> dict: + """Extract classes, functions, arrow functions, and imports from a .js/.ts/.tsx file.""" + config = _TS_CONFIG if path.suffix in (".ts", ".tsx") else _JS_CONFIG + return _extract_generic(path, config) + + +def extract_java(path: Path) -> dict: + """Extract classes, interfaces, methods, constructors, and imports from a .java file.""" + return _extract_generic(path, _JAVA_CONFIG) + + +def extract_c(path: Path) -> dict: + """Extract functions and includes from a .c/.h file.""" + return _extract_generic(path, _C_CONFIG) + + +def extract_cpp(path: Path) -> dict: + """Extract functions, classes, and includes from a .cpp/.cc/.cxx/.hpp file.""" + return _extract_generic(path, _CPP_CONFIG) + + +def extract_ruby(path: Path) -> dict: + """Extract classes, methods, singleton methods, and calls from a .rb file.""" + return _extract_generic(path, _RUBY_CONFIG) + + +def extract_csharp(path: Path) -> dict: + """Extract classes, interfaces, methods, namespaces, and usings from a .cs file.""" + return _extract_generic(path, _CSHARP_CONFIG) + + +def extract_kotlin(path: Path) -> dict: + """Extract classes, objects, functions, and imports from a .kt/.kts file.""" + return _extract_generic(path, _KOTLIN_CONFIG) + + +def extract_scala(path: Path) -> dict: + """Extract classes, objects, functions, and imports from a .scala file.""" + return _extract_generic(path, _SCALA_CONFIG) + + +def extract_php(path: Path) -> dict: + """Extract classes, functions, methods, namespace uses, and calls from a .php file.""" + return _extract_generic(path, _PHP_CONFIG) + + +def extract_blade(path: Path) -> dict: + """Extract @include, components, and wire:click bindings from Blade templates.""" + import re + try: + src = path.read_text(encoding="utf-8", errors="replace") + except OSError: + return {"error": f"cannot read {path}"} + + file_nid = _make_id(str(path)) + nodes = [{"id": file_nid, "label": path.name, "file_type": "code", + "source_file": str(path), "source_location": None}] + edges = [] + + # @include('path.to.partial') or @include("path.to.partial") + for m in re.finditer(r"@include\(['\"]([^'\"]+)['\"]", src): + tgt = m.group(1).replace(".", "/") + tgt_nid = _make_id(tgt) + if tgt_nid not in {n["id"] for n in nodes}: + nodes.append({"id": tgt_nid, "label": m.group(1), "file_type": "code", + "source_file": str(path), "source_location": None}) + edges.append({"source": file_nid, "target": tgt_nid, "relation": "includes", + "confidence": "EXTRACTED", "confidence_score": 1.0, + "source_file": str(path), "source_location": None, "weight": 1.0}) + + # or + for m in re.finditer(r" dict: + """Extract classes, mixins, functions, imports, and calls from a .dart file using regex.""" + try: + src = path.read_text(encoding="utf-8", errors="replace") + except OSError: + return {"error": f"cannot read {path}"} + + file_nid = _make_id(str(path)) + nodes = [{"id": file_nid, "label": path.name, "file_type": "code", + "source_file": str(path), "source_location": None}] + edges = [] + defined: set[str] = set() + + # Classes and mixins + for m in re.finditer(r"^\s*(?:abstract\s+)?(?:class|mixin)\s+(\w+)", src, re.MULTILINE): + nid = _make_id(str(path), m.group(1)) + if nid not in defined: + nodes.append({"id": nid, "label": m.group(1), "file_type": "code", + "source_file": str(path), "source_location": None}) + edges.append({"source": file_nid, "target": nid, "relation": "defines", + "confidence": "EXTRACTED", "confidence_score": 1.0, + "source_file": str(path), "source_location": None, "weight": 1.0}) + defined.add(nid) + + # Top-level and member functions/methods + for m in re.finditer(r"^\s*(?:static\s+|async\s+)?(?:\w+\s+)+(\w+)\s*\(", src, re.MULTILINE): + name = m.group(1) + if name in {"if", "for", "while", "switch", "catch", "return"}: + continue + nid = _make_id(str(path), name) + if nid not in defined: + nodes.append({"id": nid, "label": name, "file_type": "code", + "source_file": str(path), "source_location": None}) + edges.append({"source": file_nid, "target": nid, "relation": "defines", + "confidence": "EXTRACTED", "confidence_score": 1.0, + "source_file": str(path), "source_location": None, "weight": 1.0}) + defined.add(nid) + + # import 'package:...' or import '...' + for m in re.finditer(r"""^import\s+['"]([^'"]+)['"]""", src, re.MULTILINE): + pkg = m.group(1) + tgt_nid = _make_id(pkg) + if tgt_nid not in defined: + nodes.append({"id": tgt_nid, "label": pkg, "file_type": "code", + "source_file": str(path), "source_location": None}) + defined.add(tgt_nid) + edges.append({"source": file_nid, "target": tgt_nid, "relation": "imports", + "confidence": "EXTRACTED", "confidence_score": 1.0, + "source_file": str(path), "source_location": None, "weight": 1.0}) + + return {"nodes": nodes, "edges": edges} + + +def extract_verilog(path: Path) -> dict: + """Extract modules, functions, tasks, package imports, and instantiations from .v/.sv files.""" + try: + import tree_sitter_verilog as tsverilog + from tree_sitter import Language, Parser + except ImportError: + return {"nodes": [], "edges": [], "error": "tree_sitter_verilog not installed"} + + try: + language = Language(tsverilog.language()) + parser = Parser(language) + source = path.read_bytes() + tree = parser.parse(source) + root = tree.root_node + except Exception as e: + return {"nodes": [], "edges": [], "error": str(e)} + + stem = path.stem + str_path = str(path) + nodes: list[dict] = [] + edges: list[dict] = [] + seen_ids: set[str] = set() + + def add_node(nid: str, label: str, line: int) -> None: + if nid not in seen_ids: + seen_ids.add(nid) + nodes.append({"id": nid, "label": label, "file_type": "code", + "source_file": str_path, "source_location": f"L{line}", + "confidence_score": 1.0}) + + def add_edge(src: str, tgt: str, relation: str, line: int, + confidence: str = "EXTRACTED", score: float = 1.0) -> None: + edges.append({"source": src, "target": tgt, "relation": relation, + "confidence": confidence, "confidence_score": score, + "source_file": str_path, "source_location": f"L{line}", "weight": 1.0}) + + file_nid = _make_id(str(path)) + add_node(file_nid, path.name, 1) + + def walk(node, module_nid: str | None = None) -> None: + t = node.type + + if t == "module_declaration": + name_node = node.child_by_field_name("name") + if name_node: + mod_name = _read_text(name_node, source) + line = node.start_point[0] + 1 + nid = _make_id(stem, mod_name) + add_node(nid, mod_name, line) + add_edge(file_nid, nid, "defines", line) + for child in node.children: + walk(child, nid) + return + + elif t in ("function_declaration", "function_prototype"): + name_node = node.child_by_field_name("name") + if name_node: + func_name = _read_text(name_node, source) + line = node.start_point[0] + 1 + parent = module_nid or file_nid + nid = _make_id(parent, func_name) + add_node(nid, f"{func_name}()", line) + add_edge(parent, nid, "contains", line) + + elif t == "task_declaration": + name_node = node.child_by_field_name("name") + if name_node: + task_name = _read_text(name_node, source) + line = node.start_point[0] + 1 + parent = module_nid or file_nid + nid = _make_id(parent, task_name) + add_node(nid, task_name, line) + add_edge(parent, nid, "contains", line) + + elif t == "package_import_declaration": + for child in node.children: + if child.type == "package_import_item": + pkg_text = _read_text(child, source) + pkg_name = pkg_text.split("::")[0].strip() + if pkg_name: + line = node.start_point[0] + 1 + tgt_nid = _make_id(pkg_name) + add_node(tgt_nid, pkg_name, line) + src = module_nid or file_nid + add_edge(src, tgt_nid, "imports_from", line) + + elif t == "module_instantiation": + # module_type instantiates another module + type_node = node.child_by_field_name("module_type") + if type_node and module_nid: + inst_type = _read_text(type_node, source).strip() + if inst_type: + line = node.start_point[0] + 1 + tgt_nid = _make_id(inst_type) + add_node(tgt_nid, inst_type, line) + add_edge(module_nid, tgt_nid, "instantiates", line) for child in node.children: - walk(child, parent_class_nid=None) + walk(child, module_nid) walk(root) + return {"nodes": nodes, "edges": edges} - label_to_nid: dict[str, str] = {} - for n in nodes: - raw = n["label"] - normalised = raw.strip("()").lstrip(".") - label_to_nid[normalised.lower()] = n["id"] - seen_call_pairs: set[tuple[str, str]] = set() +def extract_lua(path: Path) -> dict: + """Extract functions, methods, require() imports, and calls from a .lua file.""" + return _extract_generic(path, _LUA_CONFIG) - def walk_calls(node, caller_nid: str) -> None: - if node.type in ("function_declaration", "arrow_function", "method_definition"): + +def extract_swift(path: Path) -> dict: + """Extract classes, structs, protocols, functions, imports, and calls from a .swift file.""" + return _extract_generic(path, _SWIFT_CONFIG) + + +# ── Julia extractor (custom walk) ──────────────────────────────────────────── + +def extract_julia(path: Path) -> dict: + """Extract modules, structs, functions, imports, and calls from a .jl file.""" + try: + import tree_sitter_julia as tsjulia + from tree_sitter import Language, Parser + except ImportError: + return {"nodes": [], "edges": [], "error": "tree-sitter-julia not installed"} + + try: + language = Language(tsjulia.language()) + parser = Parser(language) + source = path.read_bytes() + tree = parser.parse(source) + root = tree.root_node + except Exception as e: + return {"nodes": [], "edges": [], "error": str(e)} + + stem = path.stem + str_path = str(path) + nodes: list[dict] = [] + edges: list[dict] = [] + seen_ids: set[str] = set() + function_bodies: list[tuple[str, object]] = [] + + def add_node(nid: str, label: str, line: int) -> None: + if nid not in seen_ids: + seen_ids.add(nid) + nodes.append({ + "id": nid, + "label": label, + "file_type": "code", + "source_file": str_path, + "source_location": f"L{line}", + }) + + def add_edge(src: str, tgt: str, relation: str, line: int, + confidence: str = "EXTRACTED", weight: float = 1.0) -> None: + edges.append({ + "source": src, + "target": tgt, + "relation": relation, + "confidence": confidence, + "source_file": str_path, + "source_location": f"L{line}", + "weight": weight, + }) + + file_nid = _make_id(str(path)) + add_node(file_nid, path.name, 1) + + def _func_name_from_signature(sig_node) -> str | None: + """Extract function name from a Julia signature node (call_expression > identifier).""" + for child in sig_node.children: + if child.type == "call_expression": + callee = child.children[0] if child.children else None + if callee and callee.type == "identifier": + return _read_text(callee, source) + return None + + def walk_calls(body_node, func_nid: str) -> None: + if body_node is None: return - if node.type == "call_expression": - func_node = node.child_by_field_name("function") - callee_name: str | None = None - if func_node: - if func_node.type == "identifier": - callee_name = source[func_node.start_byte:func_node.end_byte].decode("utf-8", errors="replace") - elif func_node.type == "member_expression": - prop = func_node.child_by_field_name("property") - if prop: - callee_name = source[prop.start_byte:prop.end_byte].decode("utf-8", errors="replace") - if callee_name: - tgt_nid = label_to_nid.get(callee_name.lower()) - if tgt_nid and tgt_nid != caller_nid: - pair = (caller_nid, tgt_nid) - if pair not in seen_call_pairs: - seen_call_pairs.add(pair) + t = body_node.type + if t in ("function_definition", "short_function_definition"): + return + if t == "call_expression" and body_node.children: + callee = body_node.children[0] + # Direct call: foo(...) + if callee.type == "identifier": + callee_name = _read_text(callee, source) + target_nid = _make_id(stem, callee_name) + add_edge(func_nid, target_nid, "calls", body_node.start_point[0] + 1, + confidence="EXTRACTED") + # Method call: obj.method(...) + elif callee.type == "field_expression" and len(callee.children) >= 3: + method_node = callee.children[-1] + method_name = _read_text(method_node, source) + target_nid = _make_id(stem, method_name) + add_edge(func_nid, target_nid, "calls", body_node.start_point[0] + 1, + confidence="EXTRACTED") + for child in body_node.children: + walk_calls(child, func_nid) + + def walk(node, scope_nid: str) -> None: + t = node.type + + # Module + if t == "module_definition": + name_node = next((c for c in node.children if c.type == "identifier"), None) + if name_node: + mod_name = _read_text(name_node, source) + mod_nid = _make_id(stem, mod_name) + line = node.start_point[0] + 1 + add_node(mod_nid, mod_name, line) + add_edge(file_nid, mod_nid, "defines", line) + for child in node.children: + walk(child, mod_nid) + return + + # Struct (struct / mutable struct — both map to struct_definition in tree-sitter-julia) + if t == "struct_definition": + # type_head may contain: identifier (simple) or binary_expression (Foo <: Bar) + type_head = next((c for c in node.children if c.type == "type_head"), None) + if type_head: + bin_expr = next((c for c in type_head.children if c.type == "binary_expression"), None) + if bin_expr: + # First identifier is the struct name, last is the supertype + identifiers = [c for c in bin_expr.children if c.type == "identifier"] + if identifiers: + struct_name = _read_text(identifiers[0], source) + struct_nid = _make_id(stem, struct_name) line = node.start_point[0] + 1 - edges.append({ - "source": caller_nid, - "target": tgt_nid, - "relation": "calls", - "confidence": "INFERRED", - "source_file": str_path, - "source_location": f"L{line}", - "weight": 0.8, - }) - for child in node.children: - walk_calls(child, caller_nid) + add_node(struct_nid, struct_name, line) + add_edge(scope_nid, struct_nid, "defines", line) + if len(identifiers) >= 2: + super_name = _read_text(identifiers[-1], source) + add_edge(struct_nid, _make_id(stem, super_name), "inherits", + line, confidence="EXTRACTED") + else: + name_node = next((c for c in type_head.children if c.type == "identifier"), None) + if name_node: + struct_name = _read_text(name_node, source) + struct_nid = _make_id(stem, struct_name) + line = node.start_point[0] + 1 + add_node(struct_nid, struct_name, line) + add_edge(scope_nid, struct_nid, "defines", line) + return - for caller_nid, body_node in function_bodies: - walk_calls(body_node, caller_nid) + # Abstract type + if t == "abstract_definition": + # type_head > identifier + type_head = next((c for c in node.children if c.type == "type_head"), None) + if type_head: + name_node = next((c for c in type_head.children if c.type == "identifier"), None) + if name_node: + abs_name = _read_text(name_node, source) + abs_nid = _make_id(stem, abs_name) + line = node.start_point[0] + 1 + add_node(abs_nid, abs_name, line) + add_edge(scope_nid, abs_nid, "defines", line) + return - valid_ids = seen_ids - clean_edges = [] - for edge in edges: - src, tgt = edge["source"], edge["target"] - if src in valid_ids and (tgt in valid_ids or edge["relation"] in ("imports", "imports_from")): - clean_edges.append(edge) + # Function: function foo(...) ... end + if t == "function_definition": + sig_node = next((c for c in node.children if c.type == "signature"), None) + if sig_node: + func_name = _func_name_from_signature(sig_node) + if func_name: + func_nid = _make_id(stem, func_name) + line = node.start_point[0] + 1 + add_node(func_nid, f"{func_name}()", line) + add_edge(scope_nid, func_nid, "defines", line) + function_bodies.append((func_nid, node)) + return - return {"nodes": nodes, "edges": clean_edges} + # Short function: foo(x) = expr + if t == "assignment": + lhs = node.children[0] if node.children else None + if lhs and lhs.type == "call_expression" and lhs.children: + callee = lhs.children[0] + if callee.type == "identifier": + func_name = _read_text(callee, source) + func_nid = _make_id(stem, func_name) + line = node.start_point[0] + 1 + add_node(func_nid, f"{func_name}()", line) + add_edge(scope_nid, func_nid, "defines", line) + # Only walk the RHS (index 2 after lhs and operator) to avoid self-loops + rhs = node.children[-1] if len(node.children) >= 3 else None + if rhs: + function_bodies.append((func_nid, rhs)) + return + + # Using / Import + if t in ("using_statement", "import_statement"): + line = node.start_point[0] + 1 + for child in node.children: + if child.type == "identifier": + mod_name = _read_text(child, source) + imp_nid = _make_id(mod_name) + add_node(imp_nid, mod_name, line) + add_edge(scope_nid, imp_nid, "imports", line) + elif child.type == "selected_import": + identifiers = [c for c in child.children if c.type == "identifier"] + if identifiers: + pkg_name = _read_text(identifiers[0], source) + pkg_nid = _make_id(pkg_name) + add_node(pkg_nid, pkg_name, line) + add_edge(scope_nid, pkg_nid, "imports", line) + return + + for child in node.children: + walk(child, scope_nid) + + walk(root, file_nid) + + for func_nid, body_node in function_bodies: + # For function_definition nodes, walk children directly to avoid + # the boundary check returning early on the top-level node itself. + # Skip the "signature" child — it contains the function's own call_expression + # which would create a self-loop. + if body_node.type == "function_definition": + for child in body_node.children: + if child.type != "signature": + walk_calls(child, func_nid) + else: + walk_calls(body_node, func_nid) + + return {"nodes": nodes, "edges": edges} +# ── Go extractor (custom walk) ──────────────────────────────────────────────── + def extract_go(path: Path) -> dict: """Extract functions, methods, type declarations, and imports from a .go file.""" try: @@ -417,10 +1870,14 @@ def extract_go(path: Path) -> dict: return {"nodes": [], "edges": [], "error": str(e)} stem = path.stem + # Use directory name as package scope so methods on the same type across + # multiple files in a package share one canonical type node. + pkg_scope = path.parent.name or stem str_path = str(path) nodes: list[dict] = [] edges: list[dict] = [] seen_ids: set[str] = set() + function_bodies: list[tuple[str, object]] = [] def add_node(nid: str, label: str, line: int) -> None: if nid not in seen_ids: @@ -433,7 +1890,8 @@ def add_node(nid: str, label: str, line: int) -> None: "source_location": f"L{line}", }) - def add_edge_raw(src: str, tgt: str, relation: str, line: int, confidence: str = "EXTRACTED", weight: float = 1.0) -> None: + def add_edge(src: str, tgt: str, relation: str, line: int, + confidence: str = "EXTRACTED", weight: float = 1.0) -> None: edges.append({ "source": src, "target": tgt, @@ -444,22 +1902,20 @@ def add_edge_raw(src: str, tgt: str, relation: str, line: int, confidence: str = "weight": weight, }) - file_nid = _make_id(stem) + file_nid = _make_id(str(path)) add_node(file_nid, path.name, 1) - function_bodies: list[tuple[str, object]] = [] - def walk(node) -> None: t = node.type if t == "function_declaration": name_node = node.child_by_field_name("name") if name_node: - func_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") + func_name = _read_text(name_node, source) line = node.start_point[0] + 1 func_nid = _make_id(stem, func_name) add_node(func_nid, f"{func_name}()", line) - add_edge_raw(file_nid, func_nid, "contains", line) + add_edge(file_nid, func_nid, "contains", line) body = node.child_by_field_name("body") if body: function_bodies.append((func_nid, body)) @@ -473,23 +1929,23 @@ def walk(node) -> None: if param.type == "parameter_declaration": type_node = param.child_by_field_name("type") if type_node: - raw = source[type_node.start_byte:type_node.end_byte].decode("utf-8", errors="replace").lstrip("*").strip() + raw = _read_text(type_node, source).lstrip("*").strip() receiver_type = raw break name_node = node.child_by_field_name("name") if name_node: - method_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") + method_name = _read_text(name_node, source) line = node.start_point[0] + 1 if receiver_type: - parent_nid = _make_id(stem, receiver_type) + parent_nid = _make_id(pkg_scope, receiver_type) add_node(parent_nid, receiver_type, line) method_nid = _make_id(parent_nid, method_name) add_node(method_nid, f".{method_name}()", line) - add_edge_raw(parent_nid, method_nid, "method", line) + add_edge(parent_nid, method_nid, "method", line) else: method_nid = _make_id(stem, method_name) add_node(method_nid, f"{method_name}()", line) - add_edge_raw(file_nid, method_nid, "contains", line) + add_edge(file_nid, method_nid, "contains", line) body = node.child_by_field_name("body") if body: function_bodies.append((method_nid, body)) @@ -500,11 +1956,11 @@ def walk(node) -> None: if child.type == "type_spec": name_node = child.child_by_field_name("name") if name_node: - type_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") + type_name = _read_text(name_node, source) line = child.start_point[0] + 1 - type_nid = _make_id(stem, type_name) + type_nid = _make_id(pkg_scope, type_name) add_node(type_nid, type_name, line) - add_edge_raw(file_nid, type_nid, "contains", line) + add_edge(file_nid, type_nid, "contains", line) return if t == "import_declaration": @@ -514,17 +1970,17 @@ def walk(node) -> None: if spec.type == "import_spec": path_node = spec.child_by_field_name("path") if path_node: - raw = source[path_node.start_byte:path_node.end_byte].decode("utf-8", errors="replace").strip('"') + raw = _read_text(path_node, source).strip('"') module_name = raw.split("/")[-1] tgt_nid = _make_id(module_name) - add_edge_raw(file_nid, tgt_nid, "imports_from", spec.start_point[0] + 1) + add_edge(file_nid, tgt_nid, "imports_from", spec.start_point[0] + 1) elif child.type == "import_spec": path_node = child.child_by_field_name("path") if path_node: - raw = source[path_node.start_byte:path_node.end_byte].decode("utf-8", errors="replace").strip('"') + raw = _read_text(path_node, source).strip('"') module_name = raw.split("/")[-1] tgt_nid = _make_id(module_name) - add_edge_raw(file_nid, tgt_nid, "imports_from", child.start_point[0] + 1) + add_edge(file_nid, tgt_nid, "imports_from", child.start_point[0] + 1) return for child in node.children: @@ -548,1011 +2004,11 @@ def walk_calls(node, caller_nid: str) -> None: callee_name: str | None = None if func_node: if func_node.type == "identifier": - callee_name = source[func_node.start_byte:func_node.end_byte].decode("utf-8", errors="replace") + callee_name = _read_text(func_node, source) elif func_node.type == "selector_expression": field = func_node.child_by_field_name("field") if field: - callee_name = source[field.start_byte:field.end_byte].decode("utf-8", errors="replace") - if callee_name: - tgt_nid = label_to_nid.get(callee_name.lower()) - if tgt_nid and tgt_nid != caller_nid: - pair = (caller_nid, tgt_nid) - if pair not in seen_call_pairs: - seen_call_pairs.add(pair) - line = node.start_point[0] + 1 - edges.append({ - "source": caller_nid, - "target": tgt_nid, - "relation": "calls", - "confidence": "INFERRED", - "source_file": str_path, - "source_location": f"L{line}", - "weight": 0.8, - }) - for child in node.children: - walk_calls(child, caller_nid) - - for caller_nid, body_node in function_bodies: - walk_calls(body_node, caller_nid) - - valid_ids = seen_ids - clean_edges = [] - for edge in edges: - src, tgt = edge["source"], edge["target"] - if src in valid_ids and (tgt in valid_ids or edge["relation"] in ("imports", "imports_from")): - clean_edges.append(edge) - - return {"nodes": nodes, "edges": clean_edges} - - -def extract_rust(path: Path) -> dict: - """Extract functions, structs, enums, traits, impl methods, and use declarations from a .rs file.""" - try: - import tree_sitter_rust as tsrust - from tree_sitter import Language, Parser - except ImportError: - return {"nodes": [], "edges": [], "error": "tree-sitter-rust not installed"} - - try: - language = Language(tsrust.language()) - parser = Parser(language) - source = path.read_bytes() - tree = parser.parse(source) - root = tree.root_node - except Exception as e: - return {"nodes": [], "edges": [], "error": str(e)} - - stem = path.stem - str_path = str(path) - nodes: list[dict] = [] - edges: list[dict] = [] - seen_ids: set[str] = set() - - def add_node(nid: str, label: str, line: int) -> None: - if nid not in seen_ids: - seen_ids.add(nid) - nodes.append({ - "id": nid, - "label": label, - "file_type": "code", - "source_file": str_path, - "source_location": f"L{line}", - }) - - def add_edge(src: str, tgt: str, relation: str, line: int, confidence: str = "EXTRACTED", weight: float = 1.0) -> None: - edges.append({ - "source": src, - "target": tgt, - "relation": relation, - "confidence": confidence, - "source_file": str_path, - "source_location": f"L{line}", - "weight": weight, - }) - - file_nid = _make_id(stem) - add_node(file_nid, path.name, 1) - - function_bodies: list[tuple[str, object]] = [] - - def walk(node, parent_impl_nid: str | None = None) -> None: - t = node.type - - if t == "function_item": - name_node = node.child_by_field_name("name") - if name_node: - func_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") - line = node.start_point[0] + 1 - if parent_impl_nid: - func_nid = _make_id(parent_impl_nid, func_name) - add_node(func_nid, f".{func_name}()", line) - add_edge(parent_impl_nid, func_nid, "method", line) - else: - func_nid = _make_id(stem, func_name) - add_node(func_nid, f"{func_name}()", line) - add_edge(file_nid, func_nid, "contains", line) - body = node.child_by_field_name("body") - if body: - function_bodies.append((func_nid, body)) - return - - if t in ("struct_item", "enum_item", "trait_item"): - name_node = node.child_by_field_name("name") - if name_node: - item_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") - line = node.start_point[0] + 1 - item_nid = _make_id(stem, item_name) - add_node(item_nid, item_name, line) - add_edge(file_nid, item_nid, "contains", line) - return - - if t == "impl_item": - type_node = node.child_by_field_name("type") - impl_nid: str | None = None - if type_node: - type_name = source[type_node.start_byte:type_node.end_byte].decode("utf-8", errors="replace").strip() - impl_nid = _make_id(stem, type_name) - add_node(impl_nid, type_name, node.start_point[0] + 1) - body = node.child_by_field_name("body") - if body: - for child in body.children: - walk(child, parent_impl_nid=impl_nid) - return - - if t == "use_declaration": - arg = node.child_by_field_name("argument") - if arg: - raw = source[arg.start_byte:arg.end_byte].decode("utf-8", errors="replace") - clean = raw.split("{")[0].rstrip(":").rstrip("*").rstrip(":") - module_name = clean.split("::")[-1].strip() - if module_name: - tgt_nid = _make_id(module_name) - add_edge(file_nid, tgt_nid, "imports_from", node.start_point[0] + 1) - return - - for child in node.children: - walk(child, parent_impl_nid=None) - - walk(root) - - label_to_nid: dict[str, str] = {} - for n in nodes: - raw = n["label"] - normalised = raw.strip("()").lstrip(".") - label_to_nid[normalised.lower()] = n["id"] - - seen_call_pairs: set[tuple[str, str]] = set() - - def walk_calls(node, caller_nid: str) -> None: - if node.type == "function_item": - return - if node.type == "call_expression": - func_node = node.child_by_field_name("function") - callee_name: str | None = None - if func_node: - if func_node.type == "identifier": - callee_name = source[func_node.start_byte:func_node.end_byte].decode("utf-8", errors="replace") - elif func_node.type == "field_expression": - field = func_node.child_by_field_name("field") - if field: - callee_name = source[field.start_byte:field.end_byte].decode("utf-8", errors="replace") - elif func_node.type == "scoped_identifier": - name = func_node.child_by_field_name("name") - if name: - callee_name = source[name.start_byte:name.end_byte].decode("utf-8", errors="replace") - if callee_name: - tgt_nid = label_to_nid.get(callee_name.lower()) - if tgt_nid and tgt_nid != caller_nid: - pair = (caller_nid, tgt_nid) - if pair not in seen_call_pairs: - seen_call_pairs.add(pair) - line = node.start_point[0] + 1 - edges.append({ - "source": caller_nid, - "target": tgt_nid, - "relation": "calls", - "confidence": "INFERRED", - "source_file": str_path, - "source_location": f"L{line}", - "weight": 0.8, - }) - for child in node.children: - walk_calls(child, caller_nid) - - for caller_nid, body_node in function_bodies: - walk_calls(body_node, caller_nid) - - valid_ids = seen_ids - clean_edges = [] - for edge in edges: - src, tgt = edge["source"], edge["target"] - if src in valid_ids and (tgt in valid_ids or edge["relation"] in ("imports", "imports_from")): - clean_edges.append(edge) - - return {"nodes": nodes, "edges": clean_edges} - - -def extract_java(path: Path) -> dict: - """Extract classes, interfaces, methods, constructors, and imports from a .java file.""" - try: - import tree_sitter_java as tsjava - from tree_sitter import Language, Parser - except ImportError: - return {"nodes": [], "edges": [], "error": "tree-sitter-java not installed"} - - try: - language = Language(tsjava.language()) - parser = Parser(language) - source = path.read_bytes() - tree = parser.parse(source) - root = tree.root_node - except Exception as e: - return {"nodes": [], "edges": [], "error": str(e)} - - stem = path.stem - str_path = str(path) - nodes: list[dict] = [] - edges: list[dict] = [] - seen_ids: set[str] = set() - - def add_node(nid: str, label: str, line: int) -> None: - if nid not in seen_ids: - seen_ids.add(nid) - nodes.append({ - "id": nid, - "label": label, - "file_type": "code", - "source_file": str_path, - "source_location": f"L{line}", - }) - - def add_edge_raw(src: str, tgt: str, relation: str, line: int, confidence: str = "EXTRACTED", weight: float = 1.0) -> None: - edges.append({ - "source": src, - "target": tgt, - "relation": relation, - "confidence": confidence, - "source_file": str_path, - "source_location": f"L{line}", - "weight": weight, - }) - - file_nid = _make_id(stem) - add_node(file_nid, path.name, 1) - - function_bodies: list[tuple[str, object]] = [] - - def _walk_scoped_identifier(node) -> str: - """Reconstruct a dotted import path from nested scoped_identifier nodes.""" - parts: list[str] = [] - cur = node - while cur: - if cur.type == "scoped_identifier": - name_node = cur.child_by_field_name("name") - if name_node: - parts.append(source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace")) - cur = cur.child_by_field_name("scope") - elif cur.type == "identifier": - parts.append(source[cur.start_byte:cur.end_byte].decode("utf-8", errors="replace")) - break - else: - break - parts.reverse() - return ".".join(parts) - - def walk(node, parent_class_nid: str | None = None) -> None: - t = node.type - - if t == "import_declaration": - # Find scoped_identifier or identifier child - for child in node.children: - if child.type in ("scoped_identifier", "identifier"): - path_str = _walk_scoped_identifier(child) - module_name = path_str.split(".")[-1].strip("*").strip(".") or path_str.split(".")[-2] - if module_name: - tgt_nid = _make_id(module_name) - add_edge_raw(file_nid, tgt_nid, "imports", node.start_point[0] + 1) - break - return - - if t in ("class_declaration", "interface_declaration"): - name_node = node.child_by_field_name("name") - if not name_node: - return - class_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") - class_nid = _make_id(stem, class_name) - line = node.start_point[0] + 1 - add_node(class_nid, class_name, line) - add_edge_raw(file_nid, class_nid, "contains", line) - body = node.child_by_field_name("body") - if body: - for child in body.children: - walk(child, parent_class_nid=class_nid) - return - - if t in ("method_declaration", "constructor_declaration"): - name_node = node.child_by_field_name("name") - if not name_node: - return - method_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") - line = node.start_point[0] + 1 - if parent_class_nid: - method_nid = _make_id(parent_class_nid, method_name) - add_node(method_nid, f".{method_name}()", line) - add_edge_raw(parent_class_nid, method_nid, "method", line) - else: - method_nid = _make_id(stem, method_name) - add_node(method_nid, f"{method_name}()", line) - add_edge_raw(file_nid, method_nid, "contains", line) - body = node.child_by_field_name("body") - if body: - function_bodies.append((method_nid, body)) - return - - for child in node.children: - walk(child, parent_class_nid=None) - - walk(root) - - label_to_nid: dict[str, str] = {} - for n in nodes: - raw = n["label"] - normalised = raw.strip("()").lstrip(".") - label_to_nid[normalised.lower()] = n["id"] - - seen_call_pairs: set[tuple[str, str]] = set() - - def walk_calls(node, caller_nid: str) -> None: - if node.type in ("method_declaration", "constructor_declaration"): - return - if node.type == "method_invocation": - name_node = node.child_by_field_name("name") - callee_name: str | None = None - if name_node: - callee_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") - if callee_name: - tgt_nid = label_to_nid.get(callee_name.lower()) - if tgt_nid and tgt_nid != caller_nid: - pair = (caller_nid, tgt_nid) - if pair not in seen_call_pairs: - seen_call_pairs.add(pair) - line = node.start_point[0] + 1 - edges.append({ - "source": caller_nid, - "target": tgt_nid, - "relation": "calls", - "confidence": "INFERRED", - "source_file": str_path, - "source_location": f"L{line}", - "weight": 0.8, - }) - for child in node.children: - walk_calls(child, caller_nid) - - for caller_nid, body_node in function_bodies: - walk_calls(body_node, caller_nid) - - valid_ids = seen_ids - clean_edges = [] - for edge in edges: - src, tgt = edge["source"], edge["target"] - if src in valid_ids and (tgt in valid_ids or edge["relation"] in ("imports", "imports_from")): - clean_edges.append(edge) - - return {"nodes": nodes, "edges": clean_edges} - - -def extract_c(path: Path) -> dict: - """Extract functions and includes from a .c/.h file.""" - try: - import tree_sitter_c as tsc - from tree_sitter import Language, Parser - except ImportError: - return {"nodes": [], "edges": [], "error": "tree-sitter-c not installed"} - - try: - language = Language(tsc.language()) - parser = Parser(language) - source = path.read_bytes() - tree = parser.parse(source) - root = tree.root_node - except Exception as e: - return {"nodes": [], "edges": [], "error": str(e)} - - stem = path.stem - str_path = str(path) - nodes: list[dict] = [] - edges: list[dict] = [] - seen_ids: set[str] = set() - - def add_node(nid: str, label: str, line: int) -> None: - if nid not in seen_ids: - seen_ids.add(nid) - nodes.append({ - "id": nid, - "label": label, - "file_type": "code", - "source_file": str_path, - "source_location": f"L{line}", - }) - - def add_edge_raw(src: str, tgt: str, relation: str, line: int, confidence: str = "EXTRACTED", weight: float = 1.0) -> None: - edges.append({ - "source": src, - "target": tgt, - "relation": relation, - "confidence": confidence, - "source_file": str_path, - "source_location": f"L{line}", - "weight": weight, - }) - - file_nid = _make_id(stem) - add_node(file_nid, path.name, 1) - - function_bodies: list[tuple[str, object]] = [] - - def _get_func_name_from_declarator(node) -> str | None: - """Recursively unwrap declarator to find the innermost identifier.""" - if node.type == "identifier": - return source[node.start_byte:node.end_byte].decode("utf-8", errors="replace") - decl = node.child_by_field_name("declarator") - if decl: - return _get_func_name_from_declarator(decl) - # fallback: search children for identifier - for child in node.children: - if child.type == "identifier": - return source[child.start_byte:child.end_byte].decode("utf-8", errors="replace") - return None - - def walk(node) -> None: - t = node.type - - if t == "preproc_include": - # path child or string child - for child in node.children: - if child.type in ("string_literal", "system_lib_string", "string"): - raw = source[child.start_byte:child.end_byte].decode("utf-8", errors="replace").strip('"<> ') - module_name = raw.split("/")[-1].split(".")[0] - if module_name: - tgt_nid = _make_id(module_name) - add_edge_raw(file_nid, tgt_nid, "imports", node.start_point[0] + 1) - break - return - - if t == "function_definition": - declarator = node.child_by_field_name("declarator") - func_name: str | None = None - if declarator: - func_name = _get_func_name_from_declarator(declarator) - if func_name: - line = node.start_point[0] + 1 - func_nid = _make_id(stem, func_name) - add_node(func_nid, f"{func_name}()", line) - add_edge_raw(file_nid, func_nid, "contains", line) - body = node.child_by_field_name("body") - if body: - function_bodies.append((func_nid, body)) - return - - for child in node.children: - walk(child) - - walk(root) - - label_to_nid: dict[str, str] = {} - for n in nodes: - raw = n["label"] - normalised = raw.strip("()").lstrip(".") - label_to_nid[normalised.lower()] = n["id"] - - seen_call_pairs: set[tuple[str, str]] = set() - - def walk_calls(node, caller_nid: str) -> None: - if node.type == "function_definition": - return - if node.type == "call_expression": - func_node = node.child_by_field_name("function") - callee_name: str | None = None - if func_node: - if func_node.type == "identifier": - callee_name = source[func_node.start_byte:func_node.end_byte].decode("utf-8", errors="replace") - elif func_node.type == "field_expression": - field = func_node.child_by_field_name("field") - if field: - callee_name = source[field.start_byte:field.end_byte].decode("utf-8", errors="replace") - if callee_name: - tgt_nid = label_to_nid.get(callee_name.lower()) - if tgt_nid and tgt_nid != caller_nid: - pair = (caller_nid, tgt_nid) - if pair not in seen_call_pairs: - seen_call_pairs.add(pair) - line = node.start_point[0] + 1 - edges.append({ - "source": caller_nid, - "target": tgt_nid, - "relation": "calls", - "confidence": "INFERRED", - "source_file": str_path, - "source_location": f"L{line}", - "weight": 0.8, - }) - for child in node.children: - walk_calls(child, caller_nid) - - for caller_nid, body_node in function_bodies: - walk_calls(body_node, caller_nid) - - valid_ids = seen_ids - clean_edges = [] - for edge in edges: - src, tgt = edge["source"], edge["target"] - if src in valid_ids and (tgt in valid_ids or edge["relation"] in ("imports", "imports_from")): - clean_edges.append(edge) - - return {"nodes": nodes, "edges": clean_edges} - - -def extract_cpp(path: Path) -> dict: - """Extract functions, classes, and includes from a .cpp/.cc/.cxx/.hpp file.""" - try: - import tree_sitter_cpp as tscpp - from tree_sitter import Language, Parser - except ImportError: - return {"nodes": [], "edges": [], "error": "tree-sitter-cpp not installed"} - - try: - language = Language(tscpp.language()) - parser = Parser(language) - source = path.read_bytes() - tree = parser.parse(source) - root = tree.root_node - except Exception as e: - return {"nodes": [], "edges": [], "error": str(e)} - - stem = path.stem - str_path = str(path) - nodes: list[dict] = [] - edges: list[dict] = [] - seen_ids: set[str] = set() - - def add_node(nid: str, label: str, line: int) -> None: - if nid not in seen_ids: - seen_ids.add(nid) - nodes.append({ - "id": nid, - "label": label, - "file_type": "code", - "source_file": str_path, - "source_location": f"L{line}", - }) - - def add_edge_raw(src: str, tgt: str, relation: str, line: int, confidence: str = "EXTRACTED", weight: float = 1.0) -> None: - edges.append({ - "source": src, - "target": tgt, - "relation": relation, - "confidence": confidence, - "source_file": str_path, - "source_location": f"L{line}", - "weight": weight, - }) - - file_nid = _make_id(stem) - add_node(file_nid, path.name, 1) - - function_bodies: list[tuple[str, object]] = [] - - def _get_func_name_from_declarator(node) -> str | None: - """Recursively unwrap declarator to find the innermost identifier.""" - if node.type == "identifier": - return source[node.start_byte:node.end_byte].decode("utf-8", errors="replace") - if node.type == "qualified_identifier": - name_node = node.child_by_field_name("name") - if name_node: - return source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") - decl = node.child_by_field_name("declarator") - if decl: - return _get_func_name_from_declarator(decl) - for child in node.children: - if child.type == "identifier": - return source[child.start_byte:child.end_byte].decode("utf-8", errors="replace") - return None - - def walk(node, parent_class_nid: str | None = None) -> None: - t = node.type - - if t == "preproc_include": - for child in node.children: - if child.type in ("string_literal", "system_lib_string", "string"): - raw = source[child.start_byte:child.end_byte].decode("utf-8", errors="replace").strip('"<> ') - module_name = raw.split("/")[-1].split(".")[0] - if module_name: - tgt_nid = _make_id(module_name) - add_edge_raw(file_nid, tgt_nid, "imports", node.start_point[0] + 1) - break - return - - if t == "class_specifier": - name_node = node.child_by_field_name("name") - if not name_node: - return - class_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") - class_nid = _make_id(stem, class_name) - line = node.start_point[0] + 1 - add_node(class_nid, class_name, line) - add_edge_raw(file_nid, class_nid, "contains", line) - body = node.child_by_field_name("body") - if body: - for child in body.children: - walk(child, parent_class_nid=class_nid) - return - - if t == "function_definition": - declarator = node.child_by_field_name("declarator") - func_name: str | None = None - if declarator: - func_name = _get_func_name_from_declarator(declarator) - if func_name: - line = node.start_point[0] + 1 - if parent_class_nid: - func_nid = _make_id(parent_class_nid, func_name) - add_node(func_nid, f".{func_name}()", line) - add_edge_raw(parent_class_nid, func_nid, "method", line) - else: - func_nid = _make_id(stem, func_name) - add_node(func_nid, f"{func_name}()", line) - add_edge_raw(file_nid, func_nid, "contains", line) - body = node.child_by_field_name("body") - if body: - function_bodies.append((func_nid, body)) - return - - for child in node.children: - walk(child, parent_class_nid=None) - - walk(root) - - label_to_nid: dict[str, str] = {} - for n in nodes: - raw = n["label"] - normalised = raw.strip("()").lstrip(".") - label_to_nid[normalised.lower()] = n["id"] - - seen_call_pairs: set[tuple[str, str]] = set() - - def walk_calls(node, caller_nid: str) -> None: - if node.type == "function_definition": - return - if node.type == "call_expression": - func_node = node.child_by_field_name("function") - callee_name: str | None = None - if func_node: - if func_node.type == "identifier": - callee_name = source[func_node.start_byte:func_node.end_byte].decode("utf-8", errors="replace") - elif func_node.type in ("field_expression", "qualified_identifier"): - name = func_node.child_by_field_name("field") or func_node.child_by_field_name("name") - if name: - callee_name = source[name.start_byte:name.end_byte].decode("utf-8", errors="replace") - if callee_name: - tgt_nid = label_to_nid.get(callee_name.lower()) - if tgt_nid and tgt_nid != caller_nid: - pair = (caller_nid, tgt_nid) - if pair not in seen_call_pairs: - seen_call_pairs.add(pair) - line = node.start_point[0] + 1 - edges.append({ - "source": caller_nid, - "target": tgt_nid, - "relation": "calls", - "confidence": "INFERRED", - "source_file": str_path, - "source_location": f"L{line}", - "weight": 0.8, - }) - for child in node.children: - walk_calls(child, caller_nid) - - for caller_nid, body_node in function_bodies: - walk_calls(body_node, caller_nid) - - valid_ids = seen_ids - clean_edges = [] - for edge in edges: - src, tgt = edge["source"], edge["target"] - if src in valid_ids and (tgt in valid_ids or edge["relation"] in ("imports", "imports_from")): - clean_edges.append(edge) - - return {"nodes": nodes, "edges": clean_edges} - - -def extract_ruby(path: Path) -> dict: - """Extract classes, methods, singleton methods, and calls from a .rb file.""" - try: - import tree_sitter_ruby as tsruby - from tree_sitter import Language, Parser - except ImportError: - return {"nodes": [], "edges": [], "error": "tree-sitter-ruby not installed"} - - try: - language = Language(tsruby.language()) - parser = Parser(language) - source = path.read_bytes() - tree = parser.parse(source) - root = tree.root_node - except Exception as e: - return {"nodes": [], "edges": [], "error": str(e)} - - stem = path.stem - str_path = str(path) - nodes: list[dict] = [] - edges: list[dict] = [] - seen_ids: set[str] = set() - - def add_node(nid: str, label: str, line: int) -> None: - if nid not in seen_ids: - seen_ids.add(nid) - nodes.append({ - "id": nid, - "label": label, - "file_type": "code", - "source_file": str_path, - "source_location": f"L{line}", - }) - - def add_edge_raw(src: str, tgt: str, relation: str, line: int, confidence: str = "EXTRACTED", weight: float = 1.0) -> None: - edges.append({ - "source": src, - "target": tgt, - "relation": relation, - "confidence": confidence, - "source_file": str_path, - "source_location": f"L{line}", - "weight": weight, - }) - - file_nid = _make_id(stem) - add_node(file_nid, path.name, 1) - - function_bodies: list[tuple[str, object]] = [] - - def walk(node, parent_class_nid: str | None = None) -> None: - t = node.type - - if t == "class": - # name is a child node (not a field in all versions) - name_node = node.child_by_field_name("name") - if name_node is None: - for child in node.children: - if child.type in ("constant", "scope_resolution"): - name_node = child - break - if not name_node: - return - class_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") - class_nid = _make_id(stem, class_name) - line = node.start_point[0] + 1 - add_node(class_nid, class_name, line) - add_edge_raw(file_nid, class_nid, "contains", line) - body = node.child_by_field_name("body") - if body is None: - # body may not be a named field - walk all children except first/last - for child in node.children: - if child.type == "body_statement": - body = child - break - if body: - for child in body.children: - walk(child, parent_class_nid=class_nid) - return - - if t in ("method", "singleton_method"): - name_node = node.child_by_field_name("name") - if name_node is None: - for child in node.children: - if child.type == "identifier": - name_node = child - break - if not name_node: - return - method_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") - line = node.start_point[0] + 1 - if parent_class_nid: - method_nid = _make_id(parent_class_nid, method_name) - add_node(method_nid, f".{method_name}()", line) - add_edge_raw(parent_class_nid, method_nid, "method", line) - else: - method_nid = _make_id(stem, method_name) - add_node(method_nid, f"{method_name}()", line) - add_edge_raw(file_nid, method_nid, "contains", line) - body = node.child_by_field_name("body") - if body is None: - for child in node.children: - if child.type == "body_statement": - body = child - break - if body: - function_bodies.append((method_nid, body)) - return - - for child in node.children: - walk(child, parent_class_nid=None) - - walk(root) - - label_to_nid: dict[str, str] = {} - for n in nodes: - raw = n["label"] - normalised = raw.strip("()").lstrip(".") - label_to_nid[normalised.lower()] = n["id"] - - seen_call_pairs: set[tuple[str, str]] = set() - - def walk_calls(node, caller_nid: str) -> None: - if node.type in ("method", "singleton_method"): - return - if node.type == "call": - method_node = node.child_by_field_name("method") - callee_name: str | None = None - if method_node: - callee_name = source[method_node.start_byte:method_node.end_byte].decode("utf-8", errors="replace") - if callee_name: - tgt_nid = label_to_nid.get(callee_name.lower()) - if tgt_nid and tgt_nid != caller_nid: - pair = (caller_nid, tgt_nid) - if pair not in seen_call_pairs: - seen_call_pairs.add(pair) - line = node.start_point[0] + 1 - edges.append({ - "source": caller_nid, - "target": tgt_nid, - "relation": "calls", - "confidence": "INFERRED", - "source_file": str_path, - "source_location": f"L{line}", - "weight": 0.8, - }) - for child in node.children: - walk_calls(child, caller_nid) - - for caller_nid, body_node in function_bodies: - walk_calls(body_node, caller_nid) - - valid_ids = seen_ids - clean_edges = [] - for edge in edges: - src, tgt = edge["source"], edge["target"] - if src in valid_ids and (tgt in valid_ids or edge["relation"] in ("imports", "imports_from")): - clean_edges.append(edge) - - return {"nodes": nodes, "edges": clean_edges} - - -def extract_csharp(path: Path) -> dict: - """Extract classes, interfaces, methods, namespaces, and usings from a .cs file.""" - try: - import tree_sitter_c_sharp as tscsharp - from tree_sitter import Language, Parser - except ImportError: - return {"nodes": [], "edges": [], "error": "tree-sitter-c-sharp not installed"} - - try: - language = Language(tscsharp.language()) - parser = Parser(language) - source = path.read_bytes() - tree = parser.parse(source) - root = tree.root_node - except Exception as e: - return {"nodes": [], "edges": [], "error": str(e)} - - stem = path.stem - str_path = str(path) - nodes: list[dict] = [] - edges: list[dict] = [] - seen_ids: set[str] = set() - - def add_node(nid: str, label: str, line: int) -> None: - if nid not in seen_ids: - seen_ids.add(nid) - nodes.append({ - "id": nid, - "label": label, - "file_type": "code", - "source_file": str_path, - "source_location": f"L{line}", - }) - - def add_edge_raw(src: str, tgt: str, relation: str, line: int, confidence: str = "EXTRACTED", weight: float = 1.0) -> None: - edges.append({ - "source": src, - "target": tgt, - "relation": relation, - "confidence": confidence, - "source_file": str_path, - "source_location": f"L{line}", - "weight": weight, - }) - - file_nid = _make_id(stem) - add_node(file_nid, path.name, 1) - - function_bodies: list[tuple[str, object]] = [] - - def walk(node, parent_class_nid: str | None = None) -> None: - t = node.type - - if t == "using_directive": - # Extract the namespace name from the using directive - for child in node.children: - if child.type in ("qualified_name", "identifier", "name_equals"): - raw = source[child.start_byte:child.end_byte].decode("utf-8", errors="replace") - module_name = raw.split(".")[-1].strip() - if module_name: - tgt_nid = _make_id(module_name) - add_edge_raw(file_nid, tgt_nid, "imports", node.start_point[0] + 1) - break - return - - if t == "namespace_declaration": - name_node = node.child_by_field_name("name") - if name_node: - ns_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") - ns_nid = _make_id(stem, ns_name) - line = node.start_point[0] + 1 - add_node(ns_nid, ns_name, line) - add_edge_raw(file_nid, ns_nid, "contains", line) - body = node.child_by_field_name("body") - if body: - for child in body.children: - walk(child, parent_class_nid=parent_class_nid) - return - - if t in ("class_declaration", "interface_declaration"): - name_node = node.child_by_field_name("name") - if not name_node: - return - class_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") - class_nid = _make_id(stem, class_name) - line = node.start_point[0] + 1 - add_node(class_nid, class_name, line) - add_edge_raw(file_nid, class_nid, "contains", line) - body = node.child_by_field_name("body") - if body: - for child in body.children: - walk(child, parent_class_nid=class_nid) - return - - if t == "method_declaration": - name_node = node.child_by_field_name("name") - if not name_node: - return - method_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") - line = node.start_point[0] + 1 - if parent_class_nid: - method_nid = _make_id(parent_class_nid, method_name) - add_node(method_nid, f".{method_name}()", line) - add_edge_raw(parent_class_nid, method_nid, "method", line) - else: - method_nid = _make_id(stem, method_name) - add_node(method_nid, f"{method_name}()", line) - add_edge_raw(file_nid, method_nid, "contains", line) - body = node.child_by_field_name("body") - if body: - function_bodies.append((method_nid, body)) - return - - for child in node.children: - walk(child, parent_class_nid=None) - - walk(root) - - label_to_nid: dict[str, str] = {} - for n in nodes: - raw = n["label"] - normalised = raw.strip("()").lstrip(".") - label_to_nid[normalised.lower()] = n["id"] - - seen_call_pairs: set[tuple[str, str]] = set() - - def walk_calls(node, caller_nid: str) -> None: - if node.type == "method_declaration": - return - if node.type == "invocation_expression": - callee_name: str | None = None - name_node = node.child_by_field_name("name") - if name_node: - callee_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") - else: - # Try first named child - for child in node.children: - if child.is_named: - raw = source[child.start_byte:child.end_byte].decode("utf-8", errors="replace") - # member_access_expression: strip the object prefix - if "." in raw: - callee_name = raw.split(".")[-1] - else: - callee_name = raw - break + callee_name = _read_text(field, source) if callee_name: tgt_nid = label_to_nid.get(callee_name.lower()) if tgt_nid and tgt_nid != caller_nid: @@ -1564,10 +2020,10 @@ def walk_calls(node, caller_nid: str) -> None: "source": caller_nid, "target": tgt_nid, "relation": "calls", - "confidence": "INFERRED", + "confidence": "EXTRACTED", "source_file": str_path, "source_location": f"L{line}", - "weight": 0.8, + "weight": 1.0, }) for child in node.children: walk_calls(child, caller_nid) @@ -1585,16 +2041,18 @@ def walk_calls(node, caller_nid: str) -> None: return {"nodes": nodes, "edges": clean_edges} -def extract_kotlin(path: Path) -> dict: - """Extract classes, objects, functions, and imports from a .kt/.kts file.""" +# ── Rust extractor (custom walk) ────────────────────────────────────────────── + +def extract_rust(path: Path) -> dict: + """Extract functions, structs, enums, traits, impl methods, and use declarations from a .rs file.""" try: - import tree_sitter_kotlin as tskotlin + import tree_sitter_rust as tsrust from tree_sitter import Language, Parser except ImportError: - return {"nodes": [], "edges": [], "error": "tree-sitter-kotlin not installed"} + return {"nodes": [], "edges": [], "error": "tree-sitter-rust not installed"} try: - language = Language(tskotlin.language()) + language = Language(tsrust.language()) parser = Parser(language) source = path.read_bytes() tree = parser.parse(source) @@ -1607,6 +2065,7 @@ def extract_kotlin(path: Path) -> dict: nodes: list[dict] = [] edges: list[dict] = [] seen_ids: set[str] = set() + function_bodies: list[tuple[str, object]] = [] def add_node(nid: str, label: str, line: int) -> None: if nid not in seen_ids: @@ -1619,7 +2078,8 @@ def add_node(nid: str, label: str, line: int) -> None: "source_location": f"L{line}", }) - def add_edge_raw(src: str, tgt: str, relation: str, line: int, confidence: str = "EXTRACTED", weight: float = 1.0) -> None: + def add_edge(src: str, tgt: str, relation: str, line: int, + confidence: str = "EXTRACTED", weight: float = 1.0) -> None: edges.append({ "source": src, "target": tgt, @@ -1630,79 +2090,66 @@ def add_edge_raw(src: str, tgt: str, relation: str, line: int, confidence: str = "weight": weight, }) - file_nid = _make_id(stem) + file_nid = _make_id(str(path)) add_node(file_nid, path.name, 1) - function_bodies: list[tuple[str, object]] = [] - - def walk(node, parent_class_nid: str | None = None) -> None: + def walk(node, parent_impl_nid: str | None = None) -> None: t = node.type - if t == "import_header": - for child in node.children: - if child.type == "identifier": - raw = source[child.start_byte:child.end_byte].decode("utf-8", errors="replace") - tgt_nid = _make_id(raw) - add_edge_raw(file_nid, tgt_nid, "imports", node.start_point[0] + 1) - break + if t == "function_item": + name_node = node.child_by_field_name("name") + if name_node: + func_name = _read_text(name_node, source) + line = node.start_point[0] + 1 + if parent_impl_nid: + func_nid = _make_id(parent_impl_nid, func_name) + add_node(func_nid, f".{func_name}()", line) + add_edge(parent_impl_nid, func_nid, "method", line) + else: + func_nid = _make_id(stem, func_name) + add_node(func_nid, f"{func_name}()", line) + add_edge(file_nid, func_nid, "contains", line) + body = node.child_by_field_name("body") + if body: + function_bodies.append((func_nid, body)) return - if t in ("class_declaration", "object_declaration"): + if t in ("struct_item", "enum_item", "trait_item"): name_node = node.child_by_field_name("name") - if name_node is None: - for child in node.children: - if child.type == "simple_identifier": - name_node = child - break - if not name_node: - return - class_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") - class_nid = _make_id(stem, class_name) - line = node.start_point[0] + 1 - add_node(class_nid, class_name, line) - add_edge_raw(file_nid, class_nid, "contains", line) + if name_node: + item_name = _read_text(name_node, source) + line = node.start_point[0] + 1 + item_nid = _make_id(stem, item_name) + add_node(item_nid, item_name, line) + add_edge(file_nid, item_nid, "contains", line) + return + + if t == "impl_item": + type_node = node.child_by_field_name("type") + impl_nid: str | None = None + if type_node: + type_name = _read_text(type_node, source).strip() + impl_nid = _make_id(stem, type_name) + add_node(impl_nid, type_name, node.start_point[0] + 1) body = node.child_by_field_name("body") - if body is None: - for child in node.children: - if child.type == "class_body": - body = child - break if body: for child in body.children: - walk(child, parent_class_nid=class_nid) + walk(child, parent_impl_nid=impl_nid) return - if t == "function_declaration": - name_node = node.child_by_field_name("name") - if name_node is None: - for child in node.children: - if child.type == "simple_identifier": - name_node = child - break - if not name_node: - return - func_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") - line = node.start_point[0] + 1 - if parent_class_nid: - func_nid = _make_id(parent_class_nid, func_name) - add_node(func_nid, f".{func_name}()", line) - add_edge_raw(parent_class_nid, func_nid, "method", line) - else: - func_nid = _make_id(stem, func_name) - add_node(func_nid, f"{func_name}()", line) - add_edge_raw(file_nid, func_nid, "contains", line) - body = node.child_by_field_name("body") - if body is None: - for child in node.children: - if child.type == "function_body": - body = child - break - if body: - function_bodies.append((func_nid, body)) + if t == "use_declaration": + arg = node.child_by_field_name("argument") + if arg: + raw = _read_text(arg, source) + clean = raw.split("{")[0].rstrip(":").rstrip("*").rstrip(":") + module_name = clean.split("::")[-1].strip() + if module_name: + tgt_nid = _make_id(module_name) + add_edge(file_nid, tgt_nid, "imports_from", node.start_point[0] + 1) return for child in node.children: - walk(child, parent_class_nid=None) + walk(child, parent_impl_nid=None) walk(root) @@ -1715,21 +2162,22 @@ def walk(node, parent_class_nid: str | None = None) -> None: seen_call_pairs: set[tuple[str, str]] = set() def walk_calls(node, caller_nid: str) -> None: - if node.type == "function_declaration": + if node.type == "function_item": return if node.type == "call_expression": + func_node = node.child_by_field_name("function") callee_name: str | None = None - # Try first child (the callable) then look for simple_identifier - first = node.children[0] if node.children else None - if first: - if first.type == "simple_identifier": - callee_name = source[first.start_byte:first.end_byte].decode("utf-8", errors="replace") - elif first.type == "navigation_expression": - # obj.method - get the suffix - for child in reversed(first.children): - if child.type == "simple_identifier": - callee_name = source[child.start_byte:child.end_byte].decode("utf-8", errors="replace") - break + if func_node: + if func_node.type == "identifier": + callee_name = _read_text(func_node, source) + elif func_node.type == "field_expression": + field = func_node.child_by_field_name("field") + if field: + callee_name = _read_text(field, source) + elif func_node.type == "scoped_identifier": + name = func_node.child_by_field_name("name") + if name: + callee_name = _read_text(name, source) if callee_name: tgt_nid = label_to_nid.get(callee_name.lower()) if tgt_nid and tgt_nid != caller_nid: @@ -1741,10 +2189,10 @@ def walk_calls(node, caller_nid: str) -> None: "source": caller_nid, "target": tgt_nid, "relation": "calls", - "confidence": "INFERRED", + "confidence": "EXTRACTED", "source_file": str_path, "source_location": f"L{line}", - "weight": 0.8, + "weight": 1.0, }) for child in node.children: walk_calls(child, caller_nid) @@ -1762,16 +2210,18 @@ def walk_calls(node, caller_nid: str) -> None: return {"nodes": nodes, "edges": clean_edges} -def extract_scala(path: Path) -> dict: - """Extract classes, objects, functions, and imports from a .scala file.""" +# ── Zig ─────────────────────────────────────────────────────────────────────── + +def extract_zig(path: Path) -> dict: + """Extract functions, structs, enums, unions, and imports from a .zig file.""" try: - import tree_sitter_scala as tsscala + import tree_sitter_zig as tszig from tree_sitter import Language, Parser except ImportError: - return {"nodes": [], "edges": [], "error": "tree-sitter-scala not installed"} + return {"nodes": [], "edges": [], "error": "tree_sitter_zig not installed"} try: - language = Language(tsscala.language()) + language = Language(tszig.language()) parser = Parser(language) source = path.read_bytes() tree = parser.parse(source) @@ -1784,175 +2234,149 @@ def extract_scala(path: Path) -> dict: nodes: list[dict] = [] edges: list[dict] = [] seen_ids: set[str] = set() + function_bodies: list[tuple[str, Any]] = [] def add_node(nid: str, label: str, line: int) -> None: if nid not in seen_ids: seen_ids.add(nid) - nodes.append({ - "id": nid, - "label": label, - "file_type": "code", - "source_file": str_path, - "source_location": f"L{line}", - }) + nodes.append({"id": nid, "label": label, "file_type": "code", + "source_file": str_path, "source_location": f"L{line}"}) - def add_edge_raw(src: str, tgt: str, relation: str, line: int, confidence: str = "EXTRACTED", weight: float = 1.0) -> None: - edges.append({ - "source": src, - "target": tgt, - "relation": relation, - "confidence": confidence, - "source_file": str_path, - "source_location": f"L{line}", - "weight": weight, - }) + def add_edge(src: str, tgt: str, relation: str, line: int, + confidence: str = "EXTRACTED", weight: float = 1.0) -> None: + edges.append({"source": src, "target": tgt, "relation": relation, + "confidence": confidence, "source_file": str_path, + "source_location": f"L{line}", "weight": weight}) - file_nid = _make_id(stem) + file_nid = _make_id(str(path)) add_node(file_nid, path.name, 1) - function_bodies: list[tuple[str, object]] = [] + def _extract_import(node) -> None: + for child in node.children: + if child.type == "builtin_function": + bi = None + args = None + for c in child.children: + if c.type == "builtin_identifier": + bi = _read_text(c, source) + elif c.type == "arguments": + args = c + if bi in ("@import", "@cImport") and args: + for arg in args.children: + if arg.type in ("string_literal", "string"): + raw = _read_text(arg, source).strip('"') + module_name = raw.split("/")[-1].split(".")[0] + if module_name: + tgt_nid = _make_id(module_name) + add_edge(file_nid, tgt_nid, "imports_from", + node.start_point[0] + 1) + return + elif child.type == "field_expression": + _extract_import(child) + return - def walk(node, parent_class_nid: str | None = None) -> None: + def walk(node, parent_struct_nid: str | None = None) -> None: t = node.type - if t == "import_declaration": - for child in node.children: - if child.type in ("stable_id", "identifier"): - raw = source[child.start_byte:child.end_byte].decode("utf-8", errors="replace") - module_name = raw.split(".")[-1].strip("{} ") - if module_name and module_name != "_": - tgt_nid = _make_id(module_name) - add_edge_raw(file_nid, tgt_nid, "imports", node.start_point[0] + 1) - break + if t == "function_declaration": + name_node = node.child_by_field_name("name") + if name_node: + func_name = _read_text(name_node, source) + line = node.start_point[0] + 1 + if parent_struct_nid: + func_nid = _make_id(parent_struct_nid, func_name) + add_node(func_nid, f".{func_name}()", line) + add_edge(parent_struct_nid, func_nid, "method", line) + else: + func_nid = _make_id(stem, func_name) + add_node(func_nid, f"{func_name}()", line) + add_edge(file_nid, func_nid, "contains", line) + body = node.child_by_field_name("body") + if body: + function_bodies.append((func_nid, body)) return - if t in ("class_definition", "object_definition"): - name_node = node.child_by_field_name("name") - if name_node is None: - for child in node.children: - if child.type == "identifier": - name_node = child - break - if not name_node: + if t == "variable_declaration": + name_node = None + value_node = None + for child in node.children: + if child.type == "identifier": + name_node = child + elif child.type in ("struct_declaration", "enum_declaration", + "union_declaration", "builtin_function", + "field_expression"): + value_node = child + + if value_node and value_node.type == "struct_declaration": + if name_node: + struct_name = _read_text(name_node, source) + line = node.start_point[0] + 1 + struct_nid = _make_id(stem, struct_name) + add_node(struct_nid, struct_name, line) + add_edge(file_nid, struct_nid, "contains", line) + for child in value_node.children: + walk(child, parent_struct_nid=struct_nid) return - class_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") - class_nid = _make_id(stem, class_name) - line = node.start_point[0] + 1 - add_node(class_nid, class_name, line) - add_edge_raw(file_nid, class_nid, "contains", line) - body = node.child_by_field_name("body") - if body is None: - for child in node.children: - if child.type == "template_body": - body = child - break - if body: - for child in body.children: - walk(child, parent_class_nid=class_nid) - return - if t == "function_definition": - name_node = node.child_by_field_name("name") - if name_node is None: - for child in node.children: - if child.type == "identifier": - name_node = child - break - if not name_node: + if value_node and value_node.type in ("enum_declaration", "union_declaration"): + if name_node: + type_name = _read_text(name_node, source) + line = node.start_point[0] + 1 + type_nid = _make_id(stem, type_name) + add_node(type_nid, type_name, line) + add_edge(file_nid, type_nid, "contains", line) return - func_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") - line = node.start_point[0] + 1 - if parent_class_nid: - func_nid = _make_id(parent_class_nid, func_name) - add_node(func_nid, f".{func_name}()", line) - add_edge_raw(parent_class_nid, func_nid, "method", line) - else: - func_nid = _make_id(stem, func_name) - add_node(func_nid, f"{func_name}()", line) - add_edge_raw(file_nid, func_nid, "contains", line) - body = node.child_by_field_name("body") - if body: - function_bodies.append((func_nid, body)) + + if value_node and value_node.type in ("builtin_function", "field_expression"): + _extract_import(node) return for child in node.children: - walk(child, parent_class_nid=None) + walk(child, parent_struct_nid) walk(root) - label_to_nid: dict[str, str] = {} - for n in nodes: - raw = n["label"] - normalised = raw.strip("()").lstrip(".") - label_to_nid[normalised.lower()] = n["id"] - seen_call_pairs: set[tuple[str, str]] = set() def walk_calls(node, caller_nid: str) -> None: - if node.type == "function_definition": + if node.type == "function_declaration": return if node.type == "call_expression": - callee_name: str | None = None - # First child is the function being called - first = node.children[0] if node.children else None - if first: - if first.type == "identifier": - callee_name = source[first.start_byte:first.end_byte].decode("utf-8", errors="replace") - elif first.type == "field_expression": - field = first.child_by_field_name("field") - if field: - callee_name = source[field.start_byte:field.end_byte].decode("utf-8", errors="replace") - else: - for child in reversed(first.children): - if child.type == "identifier": - callee_name = source[child.start_byte:child.end_byte].decode("utf-8", errors="replace") - break - if callee_name: - tgt_nid = label_to_nid.get(callee_name.lower()) + fn = node.child_by_field_name("function") + if fn: + callee = _read_text(fn, source).split(".")[-1] + tgt_nid = next((n["id"] for n in nodes if n["label"] in + (f"{callee}()", f".{callee}()")), None) if tgt_nid and tgt_nid != caller_nid: pair = (caller_nid, tgt_nid) if pair not in seen_call_pairs: seen_call_pairs.add(pair) - line = node.start_point[0] + 1 - edges.append({ - "source": caller_nid, - "target": tgt_nid, - "relation": "calls", - "confidence": "INFERRED", - "source_file": str_path, - "source_location": f"L{line}", - "weight": 0.8, - }) + add_edge(caller_nid, tgt_nid, "calls", + node.start_point[0] + 1, + confidence="EXTRACTED", weight=1.0) for child in node.children: walk_calls(child, caller_nid) for caller_nid, body_node in function_bodies: walk_calls(body_node, caller_nid) - valid_ids = seen_ids - clean_edges = [] - for edge in edges: - src, tgt = edge["source"], edge["target"] - if src in valid_ids and (tgt in valid_ids or edge["relation"] in ("imports", "imports_from")): - clean_edges.append(edge) - + clean_edges = [e for e in edges if e["source"] in seen_ids and + (e["target"] in seen_ids or e["relation"] == "imports_from")] return {"nodes": nodes, "edges": clean_edges} -def extract_php(path: Path) -> dict: - """Extract classes, functions, methods, namespace uses, and calls from a .php file.""" +# ── PowerShell ──────────────────────────────────────────────────────────────── + +def extract_powershell(path: Path) -> dict: + """Extract functions, classes, methods, and using statements from a .ps1 file.""" try: - import tree_sitter_php as tsphp + import tree_sitter_powershell as tsps from tree_sitter import Language, Parser - try: - from tree_sitter_php import language_php - language = Language(language_php()) - except (ImportError, AttributeError): - language = Language(tsphp.language()) except ImportError: - return {"nodes": [], "edges": [], "error": "tree-sitter-php not installed"} + return {"nodes": [], "edges": [], "error": "tree_sitter_powershell not installed"} try: + language = Language(tsps.language()) parser = Parser(language) source = path.read_bytes() tree = parser.parse(source) @@ -1965,161 +2389,140 @@ def extract_php(path: Path) -> dict: nodes: list[dict] = [] edges: list[dict] = [] seen_ids: set[str] = set() + function_bodies: list[tuple[str, Any]] = [] def add_node(nid: str, label: str, line: int) -> None: if nid not in seen_ids: seen_ids.add(nid) - nodes.append({ - "id": nid, - "label": label, - "file_type": "code", - "source_file": str_path, - "source_location": f"L{line}", - }) + nodes.append({"id": nid, "label": label, "file_type": "code", + "source_file": str_path, "source_location": f"L{line}"}) - def add_edge_raw(src: str, tgt: str, relation: str, line: int, confidence: str = "EXTRACTED", weight: float = 1.0) -> None: - edges.append({ - "source": src, - "target": tgt, - "relation": relation, - "confidence": confidence, - "source_file": str_path, - "source_location": f"L{line}", - "weight": weight, - }) + def add_edge(src: str, tgt: str, relation: str, line: int, + confidence: str = "EXTRACTED", weight: float = 1.0) -> None: + edges.append({"source": src, "target": tgt, "relation": relation, + "confidence": confidence, "source_file": str_path, + "source_location": f"L{line}", "weight": weight}) - file_nid = _make_id(stem) + file_nid = _make_id(str(path)) add_node(file_nid, path.name, 1) - function_bodies: list[tuple[str, object]] = [] - - def walk(node, parent_class_nid: str | None = None) -> None: - t = node.type - - if t == "namespace_use_clause": - for child in node.children: - if child.type in ("qualified_name", "name", "identifier"): - raw = source[child.start_byte:child.end_byte].decode("utf-8", errors="replace") - module_name = raw.split("\\")[-1].strip() - if module_name: - tgt_nid = _make_id(module_name) - add_edge_raw(file_nid, tgt_nid, "imports", node.start_point[0] + 1) - break - return - - if t == "class_declaration": - name_node = node.child_by_field_name("name") - if name_node is None: - for child in node.children: - if child.type == "name": - name_node = child - break - if not name_node: - return - class_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") - class_nid = _make_id(stem, class_name) - line = node.start_point[0] + 1 - add_node(class_nid, class_name, line) - add_edge_raw(file_nid, class_nid, "contains", line) - body = node.child_by_field_name("body") - if body is None: - for child in node.children: - if child.type == "declaration_list": - body = child - break - if body: - for child in body.children: - walk(child, parent_class_nid=class_nid) - return + _PS_SKIP = frozenset({ + "using", "return", "if", "else", "elseif", "foreach", "for", + "while", "do", "switch", "try", "catch", "finally", "throw", + "break", "continue", "exit", "param", "begin", "process", "end", + }) - if t in ("function_definition", "method_declaration"): - name_node = node.child_by_field_name("name") - if name_node is None: - for child in node.children: - if child.type == "name": - name_node = child - break - if not name_node: - return - func_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") - line = node.start_point[0] + 1 - if parent_class_nid: - func_nid = _make_id(parent_class_nid, func_name) - add_node(func_nid, f".{func_name}()", line) - add_edge_raw(parent_class_nid, func_nid, "method", line) - else: + def _find_script_block_body(node): + for child in node.children: + if child.type == "script_block": + for sc in child.children: + if sc.type == "script_block_body": + return sc + return child + return None + + def walk(node, parent_class_nid: str | None = None) -> None: + t = node.type + + if t == "function_statement": + name_node = next((c for c in node.children if c.type == "function_name"), None) + if name_node: + func_name = _read_text(name_node, source) + line = node.start_point[0] + 1 func_nid = _make_id(stem, func_name) add_node(func_nid, f"{func_name}()", line) - add_edge_raw(file_nid, func_nid, "contains", line) - body = node.child_by_field_name("body") - if body is None: + add_edge(file_nid, func_nid, "contains", line) + body = _find_script_block_body(node) + if body: + function_bodies.append((func_nid, body)) + return + + if t == "class_statement": + name_node = next((c for c in node.children if c.type == "simple_name"), None) + if name_node: + class_name = _read_text(name_node, source) + line = node.start_point[0] + 1 + class_nid = _make_id(stem, class_name) + add_node(class_nid, class_name, line) + add_edge(file_nid, class_nid, "contains", line) for child in node.children: - if child.type == "compound_statement": - body = child - break - if body: - function_bodies.append((func_nid, body)) + walk(child, parent_class_nid=class_nid) + return + + if t == "class_method_definition": + name_node = next((c for c in node.children if c.type == "simple_name"), None) + if name_node: + method_name = _read_text(name_node, source) + line = node.start_point[0] + 1 + if parent_class_nid: + method_nid = _make_id(parent_class_nid, method_name) + add_node(method_nid, f".{method_name}()", line) + add_edge(parent_class_nid, method_nid, "method", line) + else: + method_nid = _make_id(stem, method_name) + add_node(method_nid, f"{method_name}()", line) + add_edge(file_nid, method_nid, "contains", line) + body = _find_script_block_body(node) + if body: + function_bodies.append((method_nid, body)) + return + + if t == "command": + cmd_name_node = next((c for c in node.children if c.type == "command_name"), None) + if cmd_name_node: + cmd_text = _read_text(cmd_name_node, source).lower() + if cmd_text == "using": + tokens = [] + for child in node.children: + if child.type == "command_elements": + for el in child.children: + if el.type == "generic_token": + tokens.append(_read_text(el, source)) + module_tokens = [t for t in tokens + if t.lower() not in ("namespace", "module", "assembly")] + if module_tokens: + module_name = module_tokens[-1].split(".")[-1] + add_edge(file_nid, _make_id(module_name), "imports_from", + node.start_point[0] + 1) return for child in node.children: - walk(child, parent_class_nid=None) + walk(child, parent_class_nid) walk(root) - label_to_nid: dict[str, str] = {} - for n in nodes: - raw = n["label"] - normalised = raw.strip("()").lstrip(".") - label_to_nid[normalised.lower()] = n["id"] - + label_to_nid = {n["label"].strip("()").lstrip(".").lower(): n["id"] for n in nodes} seen_call_pairs: set[tuple[str, str]] = set() def walk_calls(node, caller_nid: str) -> None: - if node.type in ("function_definition", "method_declaration"): + if node.type in ("function_statement", "class_statement"): return - if node.type in ("function_call_expression", "member_call_expression"): - callee_name: str | None = None - if node.type == "function_call_expression": - func_node = node.child_by_field_name("function") - if func_node: - callee_name = source[func_node.start_byte:func_node.end_byte].decode("utf-8", errors="replace") - else: - # member_call_expression: obj->method(args) - name_node = node.child_by_field_name("name") - if name_node: - callee_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") - if callee_name: - tgt_nid = label_to_nid.get(callee_name.lower()) - if tgt_nid and tgt_nid != caller_nid: - pair = (caller_nid, tgt_nid) - if pair not in seen_call_pairs: - seen_call_pairs.add(pair) - line = node.start_point[0] + 1 - edges.append({ - "source": caller_nid, - "target": tgt_nid, - "relation": "calls", - "confidence": "INFERRED", - "source_file": str_path, - "source_location": f"L{line}", - "weight": 0.8, - }) + if node.type == "command": + cmd_name_node = next((c for c in node.children if c.type == "command_name"), None) + if cmd_name_node: + cmd_text = _read_text(cmd_name_node, source) + if cmd_text.lower() not in _PS_SKIP: + tgt_nid = label_to_nid.get(cmd_text.lower()) + if tgt_nid and tgt_nid != caller_nid: + pair = (caller_nid, tgt_nid) + if pair not in seen_call_pairs: + seen_call_pairs.add(pair) + add_edge(caller_nid, tgt_nid, "calls", + node.start_point[0] + 1, + confidence="EXTRACTED", weight=1.0) for child in node.children: walk_calls(child, caller_nid) for caller_nid, body_node in function_bodies: walk_calls(body_node, caller_nid) - valid_ids = seen_ids - clean_edges = [] - for edge in edges: - src, tgt = edge["source"], edge["target"] - if src in valid_ids and (tgt in valid_ids or edge["relation"] in ("imports", "imports_from")): - clean_edges.append(edge) - + clean_edges = [e for e in edges if e["source"] in seen_ids and + (e["target"] in seen_ids or e["relation"] == "imports_from")] return {"nodes": nodes, "edges": clean_edges} +# ── Cross-file import resolution ────────────────────────────────────────────── + def _resolve_cross_file_imports( per_file: list[dict], paths: list[Path], @@ -2253,14 +2656,417 @@ def walk_imports(node) -> None: return new_edges -def extract(paths: list[Path]) -> dict: +def extract_objc(path: Path) -> dict: + """Extract interfaces, implementations, protocols, methods, and imports from .m/.mm/.h files.""" + try: + import tree_sitter_objc as tsobjc + from tree_sitter import Language, Parser + except ImportError: + return {"nodes": [], "edges": [], "error": "tree_sitter_objc not installed"} + + try: + language = Language(tsobjc.language()) + parser = Parser(language) + source = path.read_bytes() + tree = parser.parse(source) + root = tree.root_node + except Exception as e: + return {"nodes": [], "edges": [], "error": str(e)} + + stem = path.stem + str_path = str(path) + nodes: list[dict] = [] + edges: list[dict] = [] + seen_ids: set[str] = set() + method_bodies: list[tuple[str, Any]] = [] + + def add_node(nid: str, label: str, line: int) -> None: + if nid not in seen_ids: + seen_ids.add(nid) + nodes.append({"id": nid, "label": label, "file_type": "code", + "source_file": str_path, "source_location": f"L{line}"}) + + def add_edge(src: str, tgt: str, relation: str, line: int, + confidence: str = "EXTRACTED", weight: float = 1.0) -> None: + edges.append({"source": src, "target": tgt, "relation": relation, + "confidence": confidence, "source_file": str_path, + "source_location": f"L{line}", "weight": weight}) + + file_nid = _make_id(str(path)) + add_node(file_nid, path.name, 1) + + def _read(node) -> str: + return source[node.start_byte:node.end_byte].decode("utf-8", errors="replace") + + def _get_name(node, field: str) -> str | None: + n = node.child_by_field_name(field) + return _read(n) if n else None + + def walk(node, parent_nid: str | None = None) -> None: + t = node.type + line = node.start_point[0] + 1 + + if t == "preproc_include": + # #import or #import "MyClass.h" + for child in node.children: + if child.type == "system_lib_string": + raw = _read(child).strip("<>") + module = raw.split("/")[-1].replace(".h", "") + if module: + tgt_nid = _make_id(module) + add_edge(file_nid, tgt_nid, "imports", line) + elif child.type == "string_literal": + # recurse into string_literal to find string_content + for sub in child.children: + if sub.type == "string_content": + raw = _read(sub) + module = raw.split("/")[-1].replace(".h", "") + if module: + tgt_nid = _make_id(module) + add_edge(file_nid, tgt_nid, "imports", line) + return + + if t == "class_interface": + # @interface ClassName : SuperClass + # children: @interface, identifier(name), ':', identifier(super), parameterized_arguments, ... + identifiers = [c for c in node.children if c.type == "identifier"] + if not identifiers: + for child in node.children: + walk(child, parent_nid) + return + name = _read(identifiers[0]) + cls_nid = _make_id(stem, name) + add_node(cls_nid, name, line) + add_edge(file_nid, cls_nid, "contains", line) + # superclass is second identifier after ':' + colon_seen = False + for child in node.children: + if child.type == ":": + colon_seen = True + elif colon_seen and child.type == "identifier": + super_nid = _make_id(_read(child)) + add_edge(cls_nid, super_nid, "inherits", line) + colon_seen = False + elif child.type == "parameterized_arguments": + # protocols adopted + for sub in child.children: + if sub.type == "type_name": + for s in sub.children: + if s.type == "type_identifier": + proto_nid = _make_id(_read(s)) + add_edge(cls_nid, proto_nid, "imports", line) + elif child.type == "method_declaration": + walk(child, cls_nid) + return + + if t == "class_implementation": + # @implementation ClassName + name = None + for child in node.children: + if child.type == "identifier": + name = _read(child) + break + if not name: + for child in node.children: + walk(child, parent_nid) + return + impl_nid = _make_id(stem, name) + if impl_nid not in seen_ids: + add_node(impl_nid, name, line) + add_edge(file_nid, impl_nid, "contains", line) + for child in node.children: + if child.type == "implementation_definition": + for sub in child.children: + walk(sub, impl_nid) + return + + if t == "protocol_declaration": + name = None + for child in node.children: + if child.type == "identifier": + name = _read(child) + break + if name: + proto_nid = _make_id(stem, name) + add_node(proto_nid, f"<{name}>", line) + add_edge(file_nid, proto_nid, "contains", line) + for child in node.children: + walk(child, proto_nid) + return + + if t in ("method_declaration", "method_definition"): + container = parent_nid or file_nid + # method name is the first identifier child (simple selector) + # for compound selectors: identifier + method_parameter pairs + parts = [] + for child in node.children: + if child.type == "identifier": + parts.append(_read(child)) + elif child.type == "method_parameter": + for sub in child.children: + if sub.type == "identifier": + # selector keyword before ':' + pass + method_name = "".join(parts) if parts else None + if method_name: + method_nid = _make_id(container, method_name) + add_node(method_nid, f"-{method_name}", line) + add_edge(container, method_nid, "method", line) + if t == "method_definition": + method_bodies.append((method_nid, node)) + return + + for child in node.children: + walk(child, parent_nid) + + walk(root) + + # Second pass: resolve calls inside method bodies + all_method_nids = {n["id"] for n in nodes if n["id"] != file_nid} + seen_calls: set[tuple[str, str]] = set() + for caller_nid, body_node in method_bodies: + def walk_calls(n) -> None: + if n.type == "message_expression": + # [receiver selector] + for child in n.children: + if child.type in ("selector", "keyword_argument_list"): + sel = [] + if child.type == "selector": + sel.append(_read(child)) + else: + for sub in child.children: + if sub.type == "keyword_argument": + for s in sub.children: + if s.type == "selector": + sel.append(_read(s)) + method_name = "".join(sel) + for candidate in all_method_nids: + if candidate.endswith(_make_id("", method_name).lstrip("_")): + pair = (caller_nid, candidate) + if pair not in seen_calls and caller_nid != candidate: + seen_calls.add(pair) + add_edge(caller_nid, candidate, "calls", body_node.start_point[0] + 1, + confidence="EXTRACTED", weight=1.0) + for child in n.children: + walk_calls(child) + walk_calls(body_node) + + return {"nodes": nodes, "edges": edges, "input_tokens": 0, "output_tokens": 0} + + +def extract_elixir(path: Path) -> dict: + """Extract modules, functions, imports, and calls from a .ex/.exs file.""" + try: + import tree_sitter_elixir as tselixir + from tree_sitter import Language, Parser + except ImportError: + return {"nodes": [], "edges": [], "error": "tree_sitter_elixir not installed"} + + try: + language = Language(tselixir.language()) + parser = Parser(language) + source = path.read_bytes() + tree = parser.parse(source) + root = tree.root_node + except Exception as e: + return {"nodes": [], "edges": [], "error": str(e)} + + stem = path.stem + str_path = str(path) + nodes: list[dict] = [] + edges: list[dict] = [] + seen_ids: set[str] = set() + function_bodies: list[tuple[str, Any]] = [] + + def add_node(nid: str, label: str, line: int) -> None: + if nid not in seen_ids: + seen_ids.add(nid) + nodes.append({"id": nid, "label": label, "file_type": "code", + "source_file": str_path, "source_location": f"L{line}"}) + + def add_edge(src: str, tgt: str, relation: str, line: int, + confidence: str = "EXTRACTED", weight: float = 1.0) -> None: + edges.append({"source": src, "target": tgt, "relation": relation, + "confidence": confidence, "source_file": str_path, + "source_location": f"L{line}", "weight": weight}) + + file_nid = _make_id(str(path)) + add_node(file_nid, path.name, 1) + + _IMPORT_KEYWORDS = frozenset({"alias", "import", "require", "use"}) + + def _get_alias_text(node) -> str | None: + for child in node.children: + if child.type == "alias": + return source[child.start_byte:child.end_byte].decode("utf-8", errors="replace") + return None + + def walk(node, parent_module_nid: str | None = None) -> None: + if node.type != "call": + for child in node.children: + walk(child, parent_module_nid) + return + + identifier_node = None + arguments_node = None + do_block_node = None + for child in node.children: + if child.type == "identifier": + identifier_node = child + elif child.type == "arguments": + arguments_node = child + elif child.type == "do_block": + do_block_node = child + + if identifier_node is None: + for child in node.children: + walk(child, parent_module_nid) + return + + keyword = source[identifier_node.start_byte:identifier_node.end_byte].decode("utf-8", errors="replace") + line = node.start_point[0] + 1 + + if keyword == "defmodule": + module_name = _get_alias_text(arguments_node) if arguments_node else None + if not module_name: + return + module_nid = _make_id(stem, module_name) + add_node(module_nid, module_name, line) + add_edge(file_nid, module_nid, "contains", line) + if do_block_node: + for child in do_block_node.children: + walk(child, parent_module_nid=module_nid) + return + + if keyword in ("def", "defp"): + func_name = None + if arguments_node: + for child in arguments_node.children: + if child.type == "call": + for sub in child.children: + if sub.type == "identifier": + func_name = source[sub.start_byte:sub.end_byte].decode("utf-8", errors="replace") + break + elif child.type == "identifier": + func_name = source[child.start_byte:child.end_byte].decode("utf-8", errors="replace") + break + if not func_name: + return + container = parent_module_nid or file_nid + func_nid = _make_id(container, func_name) + add_node(func_nid, f"{func_name}()", line) + if parent_module_nid: + add_edge(parent_module_nid, func_nid, "method", line) + else: + add_edge(file_nid, func_nid, "contains", line) + if do_block_node: + function_bodies.append((func_nid, do_block_node)) + return + + if keyword in _IMPORT_KEYWORDS and arguments_node: + module_name = _get_alias_text(arguments_node) + if module_name: + tgt_nid = _make_id(module_name) + add_edge(file_nid, tgt_nid, "imports", line) + return + + for child in node.children: + walk(child, parent_module_nid) + + walk(root) + + label_to_nid: dict[str, str] = {} + for n in nodes: + normalised = n["label"].strip("()").lstrip(".") + label_to_nid[normalised.lower()] = n["id"] + + seen_call_pairs: set[tuple[str, str]] = set() + _SKIP_KEYWORDS = frozenset({ + "def", "defp", "defmodule", "defmacro", "defmacrop", + "defstruct", "defprotocol", "defimpl", "defguard", + "alias", "import", "require", "use", + "if", "unless", "case", "cond", "with", "for", + }) + + def walk_calls(node, caller_nid: str) -> None: + if node.type != "call": + for child in node.children: + walk_calls(child, caller_nid) + return + for child in node.children: + if child.type == "identifier": + kw = source[child.start_byte:child.end_byte].decode("utf-8", errors="replace") + if kw in _SKIP_KEYWORDS: + for c in node.children: + walk_calls(c, caller_nid) + return + break + callee_name: str | None = None + for child in node.children: + if child.type == "dot": + dot_text = source[child.start_byte:child.end_byte].decode("utf-8", errors="replace") + parts = dot_text.rstrip(".").split(".") + if parts: + callee_name = parts[-1] + break + if child.type == "identifier": + callee_name = source[child.start_byte:child.end_byte].decode("utf-8", errors="replace") + break + if callee_name: + tgt_nid = label_to_nid.get(callee_name.lower()) + if tgt_nid and tgt_nid != caller_nid: + pair = (caller_nid, tgt_nid) + if pair not in seen_call_pairs: + seen_call_pairs.add(pair) + add_edge(caller_nid, tgt_nid, "calls", + node.start_point[0] + 1, confidence="EXTRACTED", weight=1.0) + for child in node.children: + walk_calls(child, caller_nid) + + for caller_nid, body in function_bodies: + walk_calls(body, caller_nid) + + clean_edges = [e for e in edges if e["source"] in seen_ids and + (e["target"] in seen_ids or e["relation"] == "imports")] + return {"nodes": nodes, "edges": clean_edges, "input_tokens": 0, "output_tokens": 0} + + +# ── Main extract and collect_files ──────────────────────────────────────────── + + +def _check_tree_sitter_version() -> None: + """Raise a clear error if tree-sitter is too old for the new Language API.""" + try: + from tree_sitter import LANGUAGE_VERSION + except ImportError: + raise ImportError( + "tree-sitter is not installed. Run: pip install 'tree-sitter>=0.23.0'" + ) + # Language API v2 starts at LANGUAGE_VERSION 14 + if LANGUAGE_VERSION < 14: + import tree_sitter as _ts + raise RuntimeError( + f"tree-sitter {getattr(_ts, '__version__', 'unknown')} is too old. " + f"graphify requires tree-sitter >= 0.23.0 (Language API v2). " + f"Run: pip install --upgrade tree-sitter" + ) + + +def extract(paths: list[Path], cache_root: Path | None = None) -> dict: """Extract AST nodes and edges from a list of code files. Two-pass process: 1. Per-file structural extraction (classes, functions, imports) 2. Cross-file import resolution: turns file-level imports into class-level INFERRED edges (DigestAuth --uses--> Response) + + Args: + paths: files to extract from + cache_root: explicit root for graphify-out/cache/ (overrides the + inferred common path prefix). Pass Path('.') when running on a + subdirectory so the cache stays at ./graphify-out/cache/. """ + _check_tree_sitter_version() per_file: list[dict] = [] # Infer a common root for cache keys @@ -2278,117 +3084,67 @@ def extract(paths: list[Path]) -> dict: except Exception: root = Path(".") - _JS_SUFFIXES = {".js", ".ts", ".tsx"} + _DISPATCH: dict[str, Any] = { + ".py": extract_python, + ".js": extract_js, + ".jsx": extract_js, + ".mjs": extract_js, + ".ts": extract_js, + ".tsx": extract_js, + ".go": extract_go, + ".rs": extract_rust, + ".java": extract_java, + ".c": extract_c, + ".h": extract_c, + ".cpp": extract_cpp, + ".cc": extract_cpp, + ".cxx": extract_cpp, + ".hpp": extract_cpp, + ".rb": extract_ruby, + ".cs": extract_csharp, + ".kt": extract_kotlin, + ".kts": extract_kotlin, + ".scala": extract_scala, + ".php": extract_php, + ".swift": extract_swift, + ".lua": extract_lua, + ".toc": extract_lua, + ".zig": extract_zig, + ".ps1": extract_powershell, + ".ex": extract_elixir, + ".exs": extract_elixir, + ".m": extract_objc, + ".mm": extract_objc, + ".jl": extract_julia, + ".vue": extract_js, + ".svelte": extract_js, + ".dart": extract_dart, + ".v": extract_verilog, + ".sv": extract_verilog, + } - for path in paths: - if path.suffix == ".py": - cached = load_cached(path, root) - if cached is not None: - per_file.append(cached) - continue - result = extract_python(path) - if "error" not in result: - save_cached(path, result, root) - per_file.append(result) - elif path.suffix in _JS_SUFFIXES: - cached = load_cached(path, root) - if cached is not None: - per_file.append(cached) - continue - result = extract_js(path) - if "error" not in result: - save_cached(path, result, root) - per_file.append(result) - elif path.suffix == ".go": - cached = load_cached(path, root) - if cached is not None: - per_file.append(cached) - continue - result = extract_go(path) - if "error" not in result: - save_cached(path, result, root) - per_file.append(result) - elif path.suffix == ".rs": - cached = load_cached(path, root) - if cached is not None: - per_file.append(cached) - continue - result = extract_rust(path) - if "error" not in result: - save_cached(path, result, root) - per_file.append(result) - elif path.suffix == ".java": - cached = load_cached(path, root) - if cached is not None: - per_file.append(cached) - continue - result = extract_java(path) - if "error" not in result: - save_cached(path, result, root) - per_file.append(result) - elif path.suffix in {".c", ".h"}: - cached = load_cached(path, root) - if cached is not None: - per_file.append(cached) - continue - result = extract_c(path) - if "error" not in result: - save_cached(path, result, root) - per_file.append(result) - elif path.suffix in {".cpp", ".cc", ".cxx", ".hpp"}: - cached = load_cached(path, root) - if cached is not None: - per_file.append(cached) - continue - result = extract_cpp(path) - if "error" not in result: - save_cached(path, result, root) - per_file.append(result) - elif path.suffix == ".rb": - cached = load_cached(path, root) - if cached is not None: - per_file.append(cached) - continue - result = extract_ruby(path) - if "error" not in result: - save_cached(path, result, root) - per_file.append(result) - elif path.suffix == ".cs": - cached = load_cached(path, root) - if cached is not None: - per_file.append(cached) - continue - result = extract_csharp(path) - if "error" not in result: - save_cached(path, result, root) - per_file.append(result) - elif path.suffix in {".kt", ".kts"}: - cached = load_cached(path, root) - if cached is not None: - per_file.append(cached) - continue - result = extract_kotlin(path) - if "error" not in result: - save_cached(path, result, root) - per_file.append(result) - elif path.suffix == ".scala": - cached = load_cached(path, root) - if cached is not None: - per_file.append(cached) - continue - result = extract_scala(path) - if "error" not in result: - save_cached(path, result, root) - per_file.append(result) - elif path.suffix == ".php": - cached = load_cached(path, root) - if cached is not None: - per_file.append(cached) - continue - result = extract_php(path) - if "error" not in result: - save_cached(path, result, root) - per_file.append(result) + total = len(paths) + _PROGRESS_INTERVAL = 100 + for i, path in enumerate(paths): + if total >= _PROGRESS_INTERVAL and i % _PROGRESS_INTERVAL == 0 and i > 0: + print(f" AST extraction: {i}/{total} files ({i * 100 // total}%)", flush=True) + # .blade.php must be checked before suffix lookup since Path.suffix returns .php + if path.name.endswith(".blade.php"): + extractor = extract_blade + else: + extractor = _DISPATCH.get(path.suffix) + if extractor is None: + continue + cached = load_cached(path, cache_root or root) + if cached is not None: + per_file.append(cached) + continue + result = extractor(path) + if "error" not in result: + save_cached(path, result, cache_root or root) + per_file.append(result) + if total >= _PROGRESS_INTERVAL: + print(f" AST extraction: {total}/{total} files (100%)", flush=True) all_nodes: list[dict] = [] all_edges: list[dict] = [] @@ -2398,9 +3154,45 @@ def extract(paths: list[Path]) -> dict: # Add cross-file class-level edges (Python only - uses Python parser internally) py_paths = [p for p in paths if p.suffix == ".py"] - py_results = [r for r, p in zip(per_file, paths) if p.suffix == ".py"] - cross_file_edges = _resolve_cross_file_imports(py_results, py_paths) - all_edges.extend(cross_file_edges) + if py_paths: + py_results = [r for r, p in zip(per_file, paths) if p.suffix == ".py"] + try: + cross_file_edges = _resolve_cross_file_imports(py_results, py_paths) + all_edges.extend(cross_file_edges) + except Exception as exc: + import logging + logging.getLogger(__name__).warning("Cross-file import resolution failed, skipping: %s", exc) + + # Cross-file call resolution for all languages + # Each extractor saved unresolved calls in raw_calls. Now that we have all + # nodes from all files, resolve any callee that exists in another file. + global_label_to_nid: dict[str, str] = {} + for n in all_nodes: + raw = n.get("label", "") + normalised = raw.strip("()").lstrip(".") + if normalised: + global_label_to_nid[normalised.lower()] = n["id"] + + existing_pairs = {(e["source"], e["target"]) for e in all_edges} + for result in per_file: + for rc in result.get("raw_calls", []): + callee = rc.get("callee", "") + if not callee: + continue + tgt = global_label_to_nid.get(callee.lower()) + caller = rc["caller_nid"] + if tgt and tgt != caller and (caller, tgt) not in existing_pairs: + existing_pairs.add((caller, tgt)) + all_edges.append({ + "source": caller, + "target": tgt, + "relation": "calls", + "confidence": "INFERRED", + "confidence_score": 0.8, + "source_file": rc.get("source_file", ""), + "source_location": rc.get("source_location"), + "weight": 1.0, + }) return { "nodes": all_nodes, @@ -2410,20 +3202,49 @@ def extract(paths: list[Path]) -> dict: } -def collect_files(target: Path) -> list[Path]: +def collect_files(target: Path, *, follow_symlinks: bool = False, root: Path | None = None) -> list[Path]: if target.is_file(): return [target] - _EXTENSIONS = ( - "*.py", "*.js", "*.ts", "*.tsx", "*.go", "*.rs", - "*.java", "*.c", "*.h", "*.cpp", "*.cc", "*.cxx", "*.hpp", - "*.rb", "*.cs", "*.kt", "*.kts", "*.scala", "*.php", - ) - results: list[Path] = [] - for pattern in _EXTENSIONS: - results.extend( - p for p in target.rglob(pattern) - if not any(part.startswith(".") for part in p.parts) - ) + _EXTENSIONS = { + ".py", ".js", ".ts", ".tsx", ".go", ".rs", + ".java", ".c", ".h", ".cpp", ".cc", ".cxx", ".hpp", + ".rb", ".cs", ".kt", ".kts", ".scala", ".php", ".swift", + ".lua", ".toc", ".zig", ".ps1", + ".m", ".mm", + } + from graphify.detect import _load_graphifyignore, _is_ignored + ignore_root = root if root is not None else target + patterns = _load_graphifyignore(ignore_root) + + def _ignored(p: Path) -> bool: + return bool(patterns and _is_ignored(p, ignore_root, patterns)) + + if not follow_symlinks: + results: list[Path] = [] + for ext in sorted(_EXTENSIONS): + results.extend( + p for p in target.rglob(f"*{ext}") + if not any(part.startswith(".") for part in p.parts) + and not _ignored(p) + ) + return sorted(results) + # Walk with symlink following + cycle detection + results = [] + for dirpath, dirnames, filenames in os.walk(target, followlinks=True): + if os.path.islink(dirpath): + real = os.path.realpath(dirpath) + parent_real = os.path.realpath(os.path.dirname(dirpath)) + if parent_real == real or parent_real.startswith(real + os.sep): + dirnames.clear() + continue + dp = Path(dirpath) + if any(part.startswith(".") for part in dp.parts): + dirnames.clear() + continue + for fname in filenames: + p = dp / fname + if p.suffix in _EXTENSIONS and not fname.startswith(".") and not _ignored(p): + results.append(p) return sorted(results)