diff --git a/graphify/extract.py b/graphify/extract.py index 333fa39a..7390c6ca 100644 --- a/graphify/extract.py +++ b/graphify/extract.py @@ -1447,56 +1447,255 @@ def extract_blade(path: Path) -> dict: def extract_dart(path: Path) -> dict: - """Extract classes, mixins, functions, imports, and calls from a .dart file using regex.""" + """Extract classes, mixins, enums, extensions, functions, imports, and calls from a .dart file using tree-sitter.""" try: - src = path.read_text(encoding="utf-8", errors="replace") - except OSError: - return {"error": f"cannot read {path}"} + import tree_sitter_dart_orchard as tsdart + from tree_sitter import Language, Parser + except ImportError: + return {"nodes": [], "edges": [], "error": "tree-sitter-dart-orchard not installed"} + + try: + language = Language(tsdart.language()) + parser = Parser(language) + source = path.read_bytes() + tree = parser.parse(source) + root = tree.root_node + except Exception as e: + return {"nodes": [], "edges": [], "error": str(e)} + + stem = path.stem + str_path = str(path) + nodes: list[dict] = [] + edges: list[dict] = [] + seen_ids: set[str] = set() + function_bodies: list[tuple[str, object]] = [] + + def add_node(nid: str, label: str, line: int) -> None: + if nid not in seen_ids: + seen_ids.add(nid) + nodes.append({ + "id": nid, + "label": label, + "file_type": "code", + "source_file": str_path, + "source_location": f"L{line}", + }) + + def add_edge(src: str, tgt: str, relation: str, line: int) -> None: + edges.append({ + "source": src, + "target": tgt, + "relation": relation, + "confidence": "EXTRACTED", + "source_file": str_path, + "source_location": f"L{line}", + "weight": 1.0, + }) file_nid = _make_id(str(path)) - nodes = [{"id": file_nid, "label": path.name, "file_type": "code", - "source_file": str(path), "source_location": None}] - edges = [] - defined: set[str] = set() + add_node(file_nid, path.name, 1) - # Classes and mixins - for m in re.finditer(r"^\s*(?:abstract\s+)?(?:class|mixin)\s+(\w+)", src, re.MULTILINE): - nid = _make_id(str(path), m.group(1)) - if nid not in defined: - nodes.append({"id": nid, "label": m.group(1), "file_type": "code", - "source_file": str(path), "source_location": None}) - edges.append({"source": file_nid, "target": nid, "relation": "defines", - "confidence": "EXTRACTED", "confidence_score": 1.0, - "source_file": str(path), "source_location": None, "weight": 1.0}) - defined.add(nid) - - # Top-level and member functions/methods - for m in re.finditer(r"^\s*(?:static\s+|async\s+)?(?:\w+\s+)+(\w+)\s*\(", src, re.MULTILINE): - name = m.group(1) - if name in {"if", "for", "while", "switch", "catch", "return"}: - continue - nid = _make_id(str(path), name) - if nid not in defined: - nodes.append({"id": nid, "label": name, "file_type": "code", - "source_file": str(path), "source_location": None}) - edges.append({"source": file_nid, "target": nid, "relation": "defines", - "confidence": "EXTRACTED", "confidence_score": 1.0, - "source_file": str(path), "source_location": None, "weight": 1.0}) - defined.add(nid) - - # import 'package:...' or import '...' - for m in re.finditer(r"""^import\s+['"]([^'"]+)['"]""", src, re.MULTILINE): - pkg = m.group(1) - tgt_nid = _make_id(pkg) - if tgt_nid not in defined: - nodes.append({"id": tgt_nid, "label": pkg, "file_type": "code", - "source_file": str(path), "source_location": None}) - defined.add(tgt_nid) - edges.append({"source": file_nid, "target": tgt_nid, "relation": "imports", - "confidence": "EXTRACTED", "confidence_score": 1.0, - "source_file": str(path), "source_location": None, "weight": 1.0}) + # ── Helpers ────────────────────────────────────────────────────────────── - return {"nodes": nodes, "edges": edges} + def _first_child_of_type(node, *types): + for child in node.children: + if child.type in types: + return child + return None + + def _find_import_uri(node) -> str | None: + """Recursively find the string_literal inside an import_or_export node.""" + if node.type == "string_literal": + return _read_text(node, source).strip("'\"") + for child in node.children: + result = _find_import_uri(child) + if result: + return result + return None + + def _process_import(node) -> None: + line = node.start_point[0] + 1 + uri = _find_import_uri(node) + if uri: + tgt_nid = _make_id(uri) + if tgt_nid not in seen_ids: + add_node(tgt_nid, uri, line) + add_edge(file_nid, tgt_nid, "imports", line) + + def _process_supertypes(node, class_nid: str) -> None: + line = node.start_point[0] + 1 + + def _link_supertype(type_node): + name = _read_text(type_node, source) + tgt = _make_id(stem, name) + if tgt not in seen_ids: + tgt = _make_id(name) + if tgt not in seen_ids: + add_node(tgt, name, line) + add_edge(class_nid, tgt, "inherits", line) + + sup = _first_child_of_type(node, "superclass") + if sup: + for child in sup.children: + if child.type == "type_identifier": + _link_supertype(child) + elif child.type in ("mixins", "interfaces"): + for sub in child.children: + if sub.type == "type_identifier": + _link_supertype(sub) + + def _extract_sig_name(sig_node) -> str | None: + """Extract name from method_signature, function_signature, or constructor_signature.""" + t = sig_node.type + if t == "method_signature": + inner = _first_child_of_type(sig_node, "function_signature", "constructor_signature", + "factory_constructor_signature", "getter_signature", "setter_signature") + return _extract_sig_name(inner) if inner else None + if t in ("function_signature", "getter_signature", "setter_signature", + "constructor_signature", "factory_constructor_signature"): + ident = _first_child_of_type(sig_node, "identifier") + return _read_text(ident, source) if ident else None + return None + + def _process_body_children(children_list, parent_nid: str) -> None: + """Pair method_signature/declaration with sibling function_body in class/mixin/enum/extension body.""" + + def _register_method(sig_node, idx: int) -> int: + """Register a method node and optionally consume the next sibling function_body. + Returns the number of extra children consumed (0 or 1).""" + name = _extract_sig_name(sig_node) + if not name: + return 0 + line = sig_node.start_point[0] + 1 + func_nid = _make_id(parent_nid, name) + add_node(func_nid, f".{name}()", line) + add_edge(parent_nid, func_nid, "method", line) + if idx + 1 < len(children_list) and children_list[idx + 1].type == "function_body": + function_bodies.append((func_nid, children_list[idx + 1])) + return 1 + return 0 + + i = 0 + while i < len(children_list): + child = children_list[i] + if child.type == "method_signature": + skip = _register_method(child, i) + i += 1 + skip + continue + elif child.type == "declaration": + sig = _first_child_of_type(child, "constructor_signature", "factory_constructor_signature", + "function_signature") + if sig: + skip = _register_method(sig, i) + i += 1 + skip + continue + i += 1 + + # ── Main walk ──────────────────────────────────────────────────────────── + + def _process_class_like(node, body_type: str) -> None: + ident = _first_child_of_type(node, "identifier") + if not ident: + return + name = _read_text(ident, source) + nid = _make_id(stem, name) + line = node.start_point[0] + 1 + add_node(nid, name, line) + add_edge(file_nid, nid, "contains", line) + _process_supertypes(node, nid) + body = _first_child_of_type(node, body_type) + if body: + _process_body_children(list(body.children), nid) + + # Walk root children with sibling pairing for top-level functions + children = list(root.children) + i = 0 + while i < len(children): + child = children[i] + t = child.type + + if t == "import_or_export": + _process_import(child) + elif t == "class_definition": + _process_class_like(child, "class_body") + elif t == "mixin_declaration": + _process_class_like(child, "class_body") + elif t == "enum_declaration": + _process_class_like(child, "enum_body") + elif t == "extension_declaration": + _process_class_like(child, "extension_body") + elif t == "function_signature": + ident = _first_child_of_type(child, "identifier") + if ident: + name = _read_text(ident, source) + func_nid = _make_id(stem, name) + line = child.start_point[0] + 1 + add_node(func_nid, f"{name}()", line) + add_edge(file_nid, func_nid, "contains", line) + if i + 1 < len(children) and children[i + 1].type == "function_body": + function_bodies.append((func_nid, children[i + 1])) + i += 2 + continue + i += 1 + + # ── Call-graph pass ────────────────────────────────────────────────────── + + label_to_nid: dict[str, str] = {} + for n in nodes: + raw = n["label"].strip("()").lstrip(".") + label_to_nid[raw.lower()] = n["id"] + + seen_call_pairs: set[tuple[str, str]] = set() + + def _resolve_call(callee_name: str, caller_nid: str, line: int) -> None: + tgt_nid = label_to_nid.get(callee_name.lower()) + if tgt_nid and tgt_nid != caller_nid: + pair = (caller_nid, tgt_nid) + if pair not in seen_call_pairs: + seen_call_pairs.add(pair) + add_edge(caller_nid, tgt_nid, "calls", line) + + def walk_calls(start_node, caller_nid: str) -> None: + stack = [start_node] + while stack: + node = stack.pop() + if node.type in ("method_signature", "function_signature"): + continue + if node.type in ("expression_statement", "initialized_variable_definition"): + _extract_calls_from_children(node.children, caller_nid) + stack.extend(node.children) + + def _extract_calls_from_children(children_list, caller_nid: str) -> None: + """Detect calls in Dart selector chains. + Pattern: identifier selector(unconditional_assignable_selector.identifier)* selector(argument_part) + """ + last_name: str | None = None + for child in children_list: + if child.type == "identifier": + last_name = _read_text(child, source) + elif child.type == "selector": + has_args = False + for sub in child.children: + if sub.type == "argument_part": + has_args = True + break + if sub.type == "unconditional_assignable_selector": + ident = _first_child_of_type(sub, "identifier") + if ident: + last_name = _read_text(ident, source) + if has_args and last_name: + _resolve_call(last_name, caller_nid, child.start_point[0] + 1) + last_name = None + + for func_nid, body_node in function_bodies: + walk_calls(body_node, func_nid) + + # Clean dangling edges + clean_edges = [e for e in edges + if e["source"] in seen_ids and + (e["target"] in seen_ids or e["relation"] in ("imports", "imports_from"))] + + return {"nodes": nodes, "edges": clean_edges} def extract_verilog(path: Path) -> dict: diff --git a/pyproject.toml b/pyproject.toml index bd0cca6e..3158c103 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "tree-sitter-objc", "tree-sitter-julia", "tree-sitter-verilog", + "tree-sitter-dart-orchard", ] [project.urls] diff --git a/tests/fixtures/sample.dart b/tests/fixtures/sample.dart new file mode 100644 index 00000000..9387b21b --- /dev/null +++ b/tests/fixtures/sample.dart @@ -0,0 +1,51 @@ +import 'package:flutter/material.dart'; +import 'dart:async'; + +abstract class Processor { + void process(); +} + +mixin Logger { + void log(String message) { + print(message); + } +} + +class DataProcessor extends Processor with Logger { + final List items = []; + + DataProcessor(); + + void addItem(String item) { + items.add(item); + } + + @override + void process() { + validate(items); + } + + List validate(List data) { + log('validating'); + return data; + } +} + +enum Status { + active, + inactive; + + String describe() { + return name; + } +} + +extension StringExt on String { + bool get isBlank => trim().isEmpty; +} + +void createProcessor() { + final p = DataProcessor(); + p.addItem('test'); + p.process(); +} diff --git a/tests/test_languages.py b/tests/test_languages.py index 680bb4e2..7aab29fc 100644 --- a/tests/test_languages.py +++ b/tests/test_languages.py @@ -1,11 +1,11 @@ -"""Tests for language extractors: Java, C, C++, Ruby, C#, Kotlin, Scala, PHP, Swift, Go, Julia.""" +"""Tests for language extractors: Java, C, C++, Ruby, C#, Kotlin, Scala, PHP, Swift, Go, Julia, Dart.""" from __future__ import annotations from pathlib import Path import pytest from graphify.extract import ( extract_java, extract_c, extract_cpp, extract_ruby, extract_csharp, extract_kotlin, extract_scala, extract_php, - extract_swift, extract_go, extract_julia, + extract_swift, extract_go, extract_julia, extract_dart, ) FIXTURES = Path(__file__).parent / "fixtures" @@ -560,3 +560,52 @@ def test_julia_no_dangling_edges(): node_ids = {n["id"] for n in r["nodes"]} for e in r["edges"]: assert e["source"] in node_ids, f"Dangling source: {e}" + + +# ── Dart ───────────────────────────────────────────────────────────────────── + +def test_dart_no_error(): + r = extract_dart(FIXTURES / "sample.dart") + assert "error" not in r + +def test_dart_finds_class(): + r = extract_dart(FIXTURES / "sample.dart") + assert any("DataProcessor" in l for l in _labels(r)) + +def test_dart_finds_abstract_class(): + r = extract_dart(FIXTURES / "sample.dart") + assert any("Processor" in l for l in _labels(r)) + +def test_dart_finds_mixin(): + r = extract_dart(FIXTURES / "sample.dart") + assert any("Logger" in l for l in _labels(r)) + +def test_dart_finds_enum(): + r = extract_dart(FIXTURES / "sample.dart") + assert any("Status" in l for l in _labels(r)) + +def test_dart_finds_methods(): + r = extract_dart(FIXTURES / "sample.dart") + labels = _labels(r) + assert any("addItem" in l for l in labels) + assert any("process" in l for l in labels) + assert any("validate" in l for l in labels) + +def test_dart_finds_function(): + r = extract_dart(FIXTURES / "sample.dart") + assert any("createProcessor" in l for l in _labels(r)) + +def test_dart_finds_imports(): + r = extract_dart(FIXTURES / "sample.dart") + assert "imports" in _relations(r) + +def test_dart_finds_calls(): + r = extract_dart(FIXTURES / "sample.dart") + call_edges = [e for e in r["edges"] if e["relation"] == "calls"] + assert len(call_edges) >= 1 + +def test_dart_no_dangling_edges(): + r = extract_dart(FIXTURES / "sample.dart") + node_ids = {n["id"] for n in r["nodes"]} + for e in r["edges"]: + assert e["source"] in node_ids, f"Dangling source: {e}"