|
| 1 | +from typing import Optional |
| 2 | + |
| 3 | +from codegen.extensions.graph.utils import Node, NodeLabel, Relation, RelationLabel, SimpleGraph |
| 4 | +from codegen.sdk.code_generation.doc_utils.utils import safe_get_class |
| 5 | +from codegen.sdk.core.class_definition import Class |
| 6 | +from codegen.sdk.core.external_module import ExternalModule |
| 7 | +from codegen.sdk.core.function import Function |
| 8 | +from codegen.sdk.python.class_definition import PyClass |
| 9 | + |
| 10 | + |
| 11 | +def create_codebase_graph(codebase): |
| 12 | + """Create a SimpleGraph representing the codebase structure.""" |
| 13 | + # Initialize graph |
| 14 | + graph = SimpleGraph() |
| 15 | + |
| 16 | + # Track existing nodes by name to prevent duplicates |
| 17 | + node_registry = {} # name -> node_id mapping |
| 18 | + |
| 19 | + def get_or_create_node(name: str, label: NodeLabel, parent_name: Optional[str] = None, properties: dict | None = None): |
| 20 | + """Get existing node or create new one if it doesn't exist.""" |
| 21 | + full_name = f"{parent_name}.{name}" if parent_name and parent_name != "Class" else name |
| 22 | + if full_name in node_registry: |
| 23 | + return graph.nodes[node_registry[full_name]] |
| 24 | + |
| 25 | + node = Node(name=name, full_name=full_name, label=label.value, properties=properties or {}) |
| 26 | + node_registry[full_name] = node.id |
| 27 | + graph.add_node(node) |
| 28 | + return node |
| 29 | + |
| 30 | + def create_class_node(class_def): |
| 31 | + """Create a node for a class definition.""" |
| 32 | + return get_or_create_node( |
| 33 | + name=class_def.name, |
| 34 | + label=NodeLabel.CLASS, |
| 35 | + properties={ |
| 36 | + "filepath": class_def.filepath if hasattr(class_def, "filepath") else "", |
| 37 | + "source": class_def.source if hasattr(class_def, "source") else "", |
| 38 | + "type": "class", |
| 39 | + }, |
| 40 | + ) |
| 41 | + |
| 42 | + def create_function_node(func): |
| 43 | + """Create a node for a function/method.""" |
| 44 | + class_name = None |
| 45 | + if func.is_method: |
| 46 | + class_name = func.parent_class.name |
| 47 | + |
| 48 | + return get_or_create_node( |
| 49 | + name=func.name, |
| 50 | + label=NodeLabel.METHOD if class_name else NodeLabel.FUNCTION, |
| 51 | + parent_name=class_name, |
| 52 | + properties={ |
| 53 | + "filepath": func.filepath if hasattr(func, "filepath") else "", |
| 54 | + "is_async": func.is_async if hasattr(func, "is_async") else False, |
| 55 | + "source": func.source if hasattr(func, "source") else "", |
| 56 | + "type": "method" if class_name else "function", |
| 57 | + }, |
| 58 | + ) |
| 59 | + |
| 60 | + def create_function_call_node(func_call): |
| 61 | + """Create a node for a function call.""" |
| 62 | + func_def = func_call.function_definition |
| 63 | + if not func_def: |
| 64 | + return None |
| 65 | + if isinstance(func_def, ExternalModule): |
| 66 | + parent_class = safe_get_class(codebase, func_def.name) |
| 67 | + if parent_class and parent_class.get_method(func_call.name): |
| 68 | + return create_function_node(parent_class.get_method(func_call.name)) |
| 69 | + else: |
| 70 | + return None |
| 71 | + |
| 72 | + call_node = None |
| 73 | + if isinstance(func_def, Function): |
| 74 | + call_node = create_function_node(func_def) |
| 75 | + |
| 76 | + elif isinstance(func_def, Class): |
| 77 | + call_node = create_class_node(func_def) |
| 78 | + |
| 79 | + return call_node |
| 80 | + |
| 81 | + # Process all classes |
| 82 | + for class_def in codebase.classes: |
| 83 | + class_node = create_class_node(class_def) |
| 84 | + |
| 85 | + # Process methods |
| 86 | + methods = class_def.methods |
| 87 | + for method in methods: |
| 88 | + method_node = create_function_node(method) |
| 89 | + |
| 90 | + # Add DEFINES relation |
| 91 | + defines_relation = Relation( |
| 92 | + label=RelationLabel.DEFINES.value, source_id=class_node.id, target_id=method_node.id, properties={"relationship_description": "The parent class defines the method."} |
| 93 | + ) |
| 94 | + graph.add_relation(defines_relation) |
| 95 | + |
| 96 | + for call in method.function_calls: |
| 97 | + call_node = create_function_call_node(call) |
| 98 | + if call_node and call_node != method_node: |
| 99 | + call_relation = Relation( |
| 100 | + label=RelationLabel.CALLS.value, source_id=method_node.id, target_id=call_node.id, properties={"relationship_description": f"The method calls the {call_node.label}."} |
| 101 | + ) |
| 102 | + graph.add_relation(call_relation) |
| 103 | + |
| 104 | + # Add inheritance relations |
| 105 | + if class_def.parent_classes: |
| 106 | + for parent in class_def.parent_classes: |
| 107 | + if not isinstance(parent, PyClass): |
| 108 | + try: |
| 109 | + parent = codebase.get_class(parent.name, optional=True) |
| 110 | + if not parent: |
| 111 | + continue |
| 112 | + except Exception as e: |
| 113 | + print(f"parent not found: {e}") |
| 114 | + continue |
| 115 | + if not hasattr(parent, "name"): |
| 116 | + continue |
| 117 | + parent_node = create_class_node(parent) |
| 118 | + |
| 119 | + inherits_relation = Relation( |
| 120 | + label=RelationLabel.INHERITS_FROM.value, |
| 121 | + source_id=class_node.id, |
| 122 | + target_id=parent_node.id, |
| 123 | + properties={"relationship_description": "The child class inherits from the parent class."}, |
| 124 | + ) |
| 125 | + graph.add_relation(inherits_relation) |
| 126 | + |
| 127 | + for func in codebase.functions: |
| 128 | + func_node = create_function_node(func) |
| 129 | + for call in func.function_calls: |
| 130 | + call_node = create_function_call_node(call) |
| 131 | + if call_node and call_node != func_node: |
| 132 | + call_relation = Relation( |
| 133 | + label=RelationLabel.CALLS.value, source_id=func_node.id, target_id=call_node.id, properties={"relationship_description": f"The function calls the {call_node.label}."} |
| 134 | + ) |
| 135 | + graph.add_relation(call_relation) |
| 136 | + |
| 137 | + return graph |
0 commit comments