|
| 1 | +--- |
| 2 | +title: "Fixing Import Loops" |
| 3 | +description: "Learn how to identify and fix problematic import loops using Codegen." |
| 4 | +icon: "arrows-rotate" |
| 5 | +iconType: "solid" |
| 6 | +--- |
| 7 | +<Frame caption="Import loops in pytorch/torchgen/model.py"> |
| 8 | + <iframe |
| 9 | + width="100%" |
| 10 | + height="500px" |
| 11 | + scrolling="no" |
| 12 | + src={`https://www.codegen.sh/embedded/graph/?id=8b575318-ff94-41f1-94df-6e21d9de45d1&zoom=1&targetNodeName=model`} |
| 13 | + className="rounded-xl" |
| 14 | + style={{ |
| 15 | + backgroundColor: "#15141b", |
| 16 | + }} |
| 17 | + ></iframe> |
| 18 | +</Frame> |
| 19 | + |
| 20 | + |
| 21 | +Import loops occur when two or more Python modules depend on each other, creating a circular dependency. While some import cycles can be harmless, others can lead to runtime errors and make code harder to maintain. |
| 22 | + |
| 23 | +In this tutorial, we'll explore how to identify and fix problematic import cycles using Codegen. |
| 24 | + |
| 25 | +<Info> |
| 26 | +You can find the complete example code in our [examples repository](https://github.com/codegen-sh/codegen-examples/tree/main/examples/removing_import_loops_in_pytorch). |
| 27 | +</Info> |
| 28 | + |
| 29 | +## Overview |
| 30 | + |
| 31 | +The steps to identify and fix import loops are as follows: |
| 32 | +1. Detect import loops |
| 33 | +2. Visualize them |
| 34 | +3. Identify problematic cycles with mixed static/dynamic imports |
| 35 | +4. Fix these cycles using Codegen |
| 36 | + |
| 37 | +# Step 1: Detect Import Loops |
| 38 | +- Create a graph |
| 39 | +- Loop through imports in the codebase and add edges between the import files |
| 40 | +- Find strongly connected components using Networkx (the import loops) |
| 41 | +```python |
| 42 | +G = nx.MultiDiGraph() |
| 43 | +
|
| 44 | +# Add all edges to the graph |
| 45 | +for imp in codebase.imports: |
| 46 | + if imp.from_file and imp.to_file: |
| 47 | + edge_color = "red" if imp.is_dynamic else "black" |
| 48 | + edge_label = "dynamic" if imp.is_dynamic else "static" |
| 49 | +
|
| 50 | + # Store the import statement and its metadata |
| 51 | + G.add_edge( |
| 52 | + imp.to_file.filepath, |
| 53 | + imp.from_file.filepath, |
| 54 | + color=edge_color, |
| 55 | + label=edge_label, |
| 56 | + is_dynamic=imp.is_dynamic, |
| 57 | + import_statement=imp, # Store the whole import object |
| 58 | + key=id(imp.import_statement), |
| 59 | + ) |
| 60 | +# Find strongly connected components |
| 61 | +cycles = [scc for scc in nx.strongly_connected_components(G) if len(scc) > 1] |
| 62 | +
|
| 63 | +print(f"🔄 Found {len(cycles)} import cycles:") |
| 64 | +for i, cycle in enumerate(cycles, 1): |
| 65 | + print(f"\nCycle #{i}:") |
| 66 | + print(f"Size: {len(cycle)} files") |
| 67 | +
|
| 68 | + # Create subgraph for this cycle to count edges |
| 69 | + cycle_subgraph = G.subgraph(cycle) |
| 70 | +
|
| 71 | + # Count total edges |
| 72 | + total_edges = cycle_subgraph.number_of_edges() |
| 73 | + print(f"Total number of imports in cycle: {total_edges}") |
| 74 | +
|
| 75 | + # Count dynamic and static imports separately |
| 76 | + dynamic_imports = sum(1 for u, v, data in cycle_subgraph.edges(data=True) if data.get("color") == "red") |
| 77 | + static_imports = sum(1 for u, v, data in cycle_subgraph.edges(data=True) if data.get("color") == "black") |
| 78 | +
|
| 79 | + print(f"Number of dynamic imports: {dynamic_imports}") |
| 80 | + print(f"Number of static imports: {static_imports}") |
| 81 | +``` |
| 82 | + |
| 83 | + |
| 84 | +## Understanding Import Cycles |
| 85 | + |
| 86 | +Not all import cycles are problematic! Here's an example of a cycle that one may think would cause an error but it does not because due to using dynamic imports. |
| 87 | + |
| 88 | +```python |
| 89 | +# top level import in in APoT_tensor.py |
| 90 | +from quantizer.py import objectA |
| 91 | +``` |
| 92 | + |
| 93 | +```python |
| 94 | +# dynamic import in quantizer.py |
| 95 | +def some_func(): |
| 96 | + # dynamic import (evaluated when some_func() is called) |
| 97 | + from APoT_tensor.py import objectB |
| 98 | +``` |
| 99 | + |
| 100 | +<img src="/images/valid-import-loop.png" /> |
| 101 | + |
| 102 | +A dynamic import is an import defined inside of a function, method or any executable body of code which delays the import execution until that function, method or body of code is called. |
| 103 | + |
| 104 | +You can use `imp.is_dynamic` to check if the import is dynamic allowing you to investigate imports that are handled more intentionally. |
| 105 | + |
| 106 | +# Step 2: Visualize Import Loops |
| 107 | +- Create a new subgraph to visualize one cycle |
| 108 | +- color and label the edges based on their type (dynamic/static) |
| 109 | +- visualize the cycle graph using `codebase.visualize(graph)` |
| 110 | + |
| 111 | +```python |
| 112 | +cycle = cycles[0] |
| 113 | +
|
| 114 | +def create_single_loop_graph(cycle): |
| 115 | + cycle_graph = nx.MultiDiGraph() # Changed to MultiDiGraph to support multiple edges |
| 116 | + cycle = list(cycle) |
| 117 | + for i in range(len(cycle)): |
| 118 | + for j in range(len(cycle)): |
| 119 | + # Get all edges between these nodes from original graph |
| 120 | + edge_data_dict = G.get_edge_data(cycle[i], cycle[j]) |
| 121 | + if edge_data_dict: |
| 122 | + # For each edge between these nodes |
| 123 | + for edge_key, edge_data in edge_data_dict.items(): |
| 124 | + # Add edge with all its attributes to cycle graph |
| 125 | + cycle_graph.add_edge(cycle[i], cycle[j], **edge_data) |
| 126 | + return cycle_graph |
| 127 | +
|
| 128 | +
|
| 129 | +cycle_graph = create_single_loop_graph(cycle) |
| 130 | +codebase.visualize(cycle_graph) |
| 131 | +``` |
| 132 | + |
| 133 | +<Frame caption="Import loops in pytorch/torchgen/model.py"> |
| 134 | + <iframe |
| 135 | + width="100%" |
| 136 | + height="500px" |
| 137 | + scrolling="no" |
| 138 | + src={`https://www.codegen.sh/embedded/graph/?id=8b575318-ff94-41f1-94df-6e21d9de45d1&zoom=1&targetNodeName=model`} |
| 139 | + className="rounded-xl" |
| 140 | + style={{ |
| 141 | + backgroundColor: "#15141b", |
| 142 | + }} |
| 143 | + ></iframe> |
| 144 | +</Frame> |
| 145 | + |
| 146 | + |
| 147 | +# Step 3: Identify problematic cycles with mixed static & dynamic imports |
| 148 | + |
| 149 | +The import loops that we are really concerned about are those that have mixed static/dynamic imports. |
| 150 | + |
| 151 | +Here's an example of a problematic cycle that we want to fix: |
| 152 | + |
| 153 | +```python |
| 154 | +# In flex_decoding.py |
| 155 | +from .flex_attention import ( |
| 156 | + compute_forward_block_mn, |
| 157 | + compute_forward_inner, |
| 158 | + # ... more static imports |
| 159 | +) |
| 160 | +
|
| 161 | +# Also in flex_decoding.py |
| 162 | +def create_flex_decoding_kernel(*args, **kwargs): |
| 163 | + from .flex_attention import set_head_dim_values # dynamic import |
| 164 | +``` |
| 165 | + |
| 166 | +It's clear that there is both a top level and a dynamic import that imports from the *same* module. Thus, this can cause issues if not handled carefully. |
| 167 | + |
| 168 | +<img src="/images/problematic-import-loop.png" /> |
| 169 | + |
| 170 | +Let's find these problematic cycles: |
| 171 | + |
| 172 | +```python |
| 173 | +def find_problematic_import_loops(G, sccs): |
| 174 | + """Find cycles where files have both static and dynamic imports between them.""" |
| 175 | + problematic_cycles = [] |
| 176 | +
|
| 177 | + for i, scc in enumerate(sccs): |
| 178 | + if i == 2: # skipping the second import loop as it's incredibly long (it's also invalid) |
| 179 | + continue |
| 180 | + mixed_import_files = {} # (from_file, to_file) -> {dynamic: count, static: count} |
| 181 | +
|
| 182 | + # Check all file pairs in the cycle |
| 183 | + for from_file in scc: |
| 184 | + for to_file in scc: |
| 185 | + if G.has_edge(from_file, to_file): |
| 186 | + # Get all edges between these files |
| 187 | + edges = G.get_edge_data(from_file, to_file) |
| 188 | +
|
| 189 | + # Count imports by type |
| 190 | + dynamic_count = sum(1 for e in edges.values() if e["color"] == "red") |
| 191 | + static_count = sum(1 for e in edges.values() if e["color"] == "black") |
| 192 | +
|
| 193 | + # If we have both types between same files, this is problematic |
| 194 | + if dynamic_count > 0 and static_count > 0: |
| 195 | + mixed_import_files[(from_file, to_file)] = {"dynamic": dynamic_count, "static": static_count, "edges": edges} |
| 196 | +
|
| 197 | + if mixed_import_files: |
| 198 | + problematic_cycles.append({"files": scc, "mixed_imports": mixed_import_files, "index": i}) |
| 199 | +
|
| 200 | + # Print findings |
| 201 | + print(f"Found {len(problematic_cycles)} cycles with mixed imports:") |
| 202 | + for i, cycle in enumerate(problematic_cycles): |
| 203 | + print(f"\n⚠️ Problematic Cycle #{i + 1}:") |
| 204 | + print(f"\n⚠️ Index #{cycle['index']}:") |
| 205 | + print(f"Size: {len(cycle['files'])} files") |
| 206 | +
|
| 207 | + for (from_file, to_file), data in cycle["mixed_imports"].items(): |
| 208 | + print("\n📁 Mixed imports detected:") |
| 209 | + print(f" From: {from_file}") |
| 210 | + print(f" To: {to_file}") |
| 211 | + print(f" Dynamic imports: {data['dynamic']}") |
| 212 | + print(f" Static imports: {data['static']}") |
| 213 | +
|
| 214 | + return problematic_cycles |
| 215 | +
|
| 216 | +problematic_cycles = find_problematic_import_loops(G, cycles) |
| 217 | +``` |
| 218 | + |
| 219 | +# Step 4: Fix the loop by moving the shared symbols to a separate `utils.py` file |
| 220 | +One common fix to this problem to break this cycle is to move all the shared symbols to a separate `utils.py` file. We can do this using the method `symbol.move_to_file`: |
| 221 | + |
| 222 | +```python |
| 223 | +# Create new utils file |
| 224 | +utils_file = codebase.create_file("torch/_inductor/kernel/flex_utils.py") |
| 225 | +
|
| 226 | +# Get the two files involved in the import cycle |
| 227 | +decoding_file = codebase.get_file("torch/_inductor/kernel/flex_decoding.py") |
| 228 | +attention_file = codebase.get_file("torch/_inductor/kernel/flex_attention.py") |
| 229 | +attention_file_path = "torch/_inductor/kernel/flex_attention.py" |
| 230 | +decoding_file_path = "torch/_inductor/kernel/flex_decoding.py" |
| 231 | +
|
| 232 | +# Track symbols to move |
| 233 | +symbols_to_move = set() |
| 234 | +
|
| 235 | +# Find imports from flex_attention in flex_decoding |
| 236 | +for imp in decoding_file.imports: |
| 237 | + if imp.from_file and imp.from_file.filepath == attention_file_path: |
| 238 | + # Get the actual symbol from flex_attention |
| 239 | + if imp.imported_symbol: |
| 240 | + symbols_to_move.add(imp.imported_symbol) |
| 241 | +
|
| 242 | +# Move identified symbols to utils file |
| 243 | +for symbol in symbols_to_move: |
| 244 | + symbol.move_to_file(utils_file) |
| 245 | +
|
| 246 | +print(f"🔄 Moved {len(symbols_to_move)} symbols to flex_utils.py") |
| 247 | +for symbol in symbols_to_move: |
| 248 | + print(symbol.name) |
| 249 | +``` |
| 250 | + |
| 251 | +```python |
| 252 | +# run this command to have the changes take effect in the codebase |
| 253 | +codebase.commit() |
| 254 | +``` |
| 255 | + |
| 256 | +Next Steps |
| 257 | +Verify all tests pass after the migration and fix other problematic import loops using the suggested strategies: |
| 258 | + 1. Move the shared symbols to a separate file |
| 259 | + 2. If a module needs imports only for type hints, consider using `if TYPE_CHECKING` from the `typing` module |
| 260 | + 3. Use lazy imports using `importlib` to load imports dynamically |
0 commit comments