Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 78 additions & 24 deletions src/impulse/application/use_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def draw_graph(
get_top_level_package: Callable[[str], str],
build_graph: Callable[[str], grimp.ImportGraph],
viewer: ports.GraphViewer,
depth: int = 1,
) -> None:
"""
Create a file showing a graph of the supplied package.
Expand All @@ -28,28 +29,60 @@ def draw_graph(
build_graph: the function which builds the graph of the supplied package
(pass grimp.build_graph or a test double).
viewer: GraphViewer for generating the graph image and opening it.
depth: the depth of submodules to include in the graph (default: 1 for direct children).
"""
# Add current directory to the path, as this doesn't happen automatically.
sys_path.insert(0, current_directory)

top_level_package = get_top_level_package(module_name)
grimp_graph = build_graph(top_level_package)

dot = _build_dot(grimp_graph, module_name, show_import_totals, show_cycle_breakers)
dot = _build_dot(grimp_graph, module_name, show_import_totals, show_cycle_breakers, depth)

viewer.view(dot)


def _find_modules_up_to_depth(
grimp_graph: grimp.ImportGraph, module_name: str, depth: int
) -> Set[str]:
"""
Find all modules up to and including the specified depth below the given module.

For depth=1, returns direct children.
For depth=2, returns direct children AND grandchildren.
And so on.
"""
if depth < 1:
raise ValueError("Depth must be at least 1")

all_modules: set[str] = set()
current_level = {module_name}

for _ in range(depth):
next_level: set[str] = set()
for mod in current_level:
next_level.update(grimp_graph.find_children(mod))
all_modules.update(next_level)
current_level = next_level

return all_modules


class _DotGraphBuildStrategy:
def __init__(self, depth: int = 1) -> None:
self.depth = depth

def build(self, module_name: str, grimp_graph: grimp.ImportGraph) -> dotfile.DotGraph:
children = grimp_graph.find_children(module_name)
modules = _find_modules_up_to_depth(grimp_graph, module_name, self.depth)

self.prepare_graph(grimp_graph, children)
self.prepare_graph(grimp_graph, modules)

dot = dotfile.DotGraph(title=module_name, concentrate=self.should_concentrate())
for child in children:
dot.add_node(child)
for upstream, downstream in itertools.permutations(children, r=2):
dot = dotfile.DotGraph(
title=module_name, concentrate=self.should_concentrate(), depth=self.depth
)
for mod in modules:
dot.add_node(mod)
for upstream, downstream in itertools.permutations(modules, r=2):
if edge := self.build_edge(grimp_graph, upstream, downstream):
dot.add_edge(edge)

Expand All @@ -58,7 +91,7 @@ def build(self, module_name: str, grimp_graph: grimp.ImportGraph) -> dotfile.Dot
def should_concentrate(self) -> bool:
return True

def prepare_graph(self, grimp_graph: grimp.ImportGraph, children: Set[str]) -> None:
def prepare_graph(self, grimp_graph: grimp.ImportGraph, modules: Set[str]) -> None:
pass

def build_edge(
Expand All @@ -70,9 +103,9 @@ def build_edge(
class _ModuleSquashingBuildStrategy(_DotGraphBuildStrategy):
"""Fast builder for when we don't need additional data about the imports."""

def prepare_graph(self, grimp_graph: grimp.ImportGraph, children: Set[str]) -> None:
for child in children:
grimp_graph.squash_module(child)
def prepare_graph(self, grimp_graph: grimp.ImportGraph, modules: Set[str]) -> None:
for mod in modules:
grimp_graph.squash_module(mod)

def build_edge(
self, grimp_graph: grimp.ImportGraph, upstream: str, downstream: str
Expand All @@ -84,12 +117,18 @@ def build_edge(

class _ImportExpressionBuildStrategy(_DotGraphBuildStrategy):
"""Slower builder for when we want to work on the whole graph,
without squashing children.
without squashing modules.
"""

def __init__(
self, *, module_name: str, show_import_totals: bool, show_cycle_breakers: bool
self,
*,
module_name: str,
show_import_totals: bool,
show_cycle_breakers: bool,
depth: int = 1,
) -> None:
super().__init__(depth=depth)
self.module_name = module_name
self.show_import_totals = show_import_totals
self.show_cycle_breakers = show_cycle_breakers
Expand All @@ -99,22 +138,22 @@ def should_concentrate(self) -> bool:
# We need to see edge direction emphasized separately.
return not (self.show_import_totals or self.show_cycle_breakers)

def prepare_graph(self, grimp_graph: grimp.ImportGraph, children: Set[str]) -> None:
super().prepare_graph(grimp_graph, children)
def prepare_graph(self, grimp_graph: grimp.ImportGraph, modules: Set[str]) -> None:
super().prepare_graph(grimp_graph, modules)

if self.show_cycle_breakers:
self.cycle_breakers = self._get_coarse_grained_cycle_breakers(grimp_graph, children)
self.cycle_breakers = self._get_coarse_grained_cycle_breakers(grimp_graph, modules)

def _get_coarse_grained_cycle_breakers(
self, grimp_graph: grimp.ImportGraph, children: Set[str]
self, grimp_graph: grimp.ImportGraph, modules: Set[str]
) -> set[tuple[str, str]]:
# In the form (importer, imported).
coarse_grained_cycle_breakers: set[tuple[str, str]] = set()

for fine_grained_cycle_breaker in grimp_graph.nominate_cycle_breakers(self.module_name):
importer, imported = fine_grained_cycle_breaker
importer_ancestor = self._get_self_or_ancestor(candidate=importer, ancestors=children)
imported_ancestor = self._get_self_or_ancestor(candidate=imported, ancestors=children)
importer_ancestor = self._get_self_or_ancestor(candidate=importer, ancestors=modules)
imported_ancestor = self._get_self_or_ancestor(candidate=imported, ancestors=modules)

if importer_ancestor and imported_ancestor:
coarse_grained_cycle_breakers.add((importer_ancestor, imported_ancestor))
Expand All @@ -131,9 +170,19 @@ def _get_self_or_ancestor(candidate: str, ancestors: Set[str]) -> str | None:
def build_edge(
self, grimp_graph: grimp.ImportGraph, upstream: str, downstream: str
) -> dotfile.Edge | None:
if grimp_graph.direct_import_exists(
importer=downstream, imported=upstream, as_packages=True
):
# For depth > 1, we can't use as_packages=True because modules may share
# descendants (e.g., foo.blue and foo.blue.alpha are both in our set).
# In that case, only check for direct imports between exact modules.
if self.depth > 1:
import_exists = grimp_graph.direct_import_exists(
importer=downstream, imported=upstream, as_packages=False
)
else:
import_exists = grimp_graph.direct_import_exists(
importer=downstream, imported=upstream, as_packages=True
)

if import_exists:
if self.show_import_totals:
number_of_imports = self._count_imports_between_packages(
grimp_graph, importer=downstream, imported=upstream
Expand Down Expand Up @@ -183,15 +232,20 @@ def _build_dot(
module_name: str,
show_import_totals: bool,
show_cycle_breakers: bool,
depth: int = 1,
) -> dotfile.DotGraph:
strategy: _DotGraphBuildStrategy
if show_import_totals or show_cycle_breakers:
# Use ImportExpressionBuildStrategy when:
# - show_import_totals or show_cycle_breakers is enabled, OR
# - depth > 1 (squashing would remove deeper modules we want to show)
if show_import_totals or show_cycle_breakers or depth > 1:
strategy = _ImportExpressionBuildStrategy(
module_name=module_name,
show_import_totals=show_import_totals,
show_cycle_breakers=show_cycle_breakers,
depth=depth,
)
else:
strategy = _ModuleSquashingBuildStrategy()
strategy = _ModuleSquashingBuildStrategy(depth=depth)

return strategy.build(module_name, grimp_graph)
8 changes: 8 additions & 0 deletions src/impulse/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,20 @@ def main():
help="Output format (default to html).",
)
@click.option("--force-console", is_flag=True, help="Force the use of the console output.")
@click.option(
"--depth",
type=int,
default=1,
help="Depth of submodules to include in the graph (default: 1 for direct children).",
)
@click.argument("module_name", type=str)
def drawgraph(
module_name: str,
show_import_totals: bool,
show_cycle_breakers: bool,
force_console: bool,
format: str,
depth: int,
) -> None:
viewer: ports.GraphViewer
if format == "html":
Expand All @@ -58,6 +65,7 @@ def drawgraph(
module_name=module_name,
show_import_totals=show_import_totals,
show_cycle_breakers=show_cycle_breakers,
depth=depth,
sys_path=sys.path,
current_directory=os.getcwd(),
get_top_level_package=adapters.get_top_level_package,
Expand Down
25 changes: 18 additions & 7 deletions src/impulse/dotfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ class Edge:
emphasized: bool = False

def __str__(self) -> str:
return f'"{DotGraph.render_module(self.source)}" -> "{DotGraph.render_module(self.destination)}"{self._render_attrs()}\n'
return self.render(base_module="")

def render(self, base_module: str) -> str:
return f'"{DotGraph.render_module(self.source, base_module)}" -> "{DotGraph.render_module(self.destination, base_module)}"{self._render_attrs()}\n'

def _render_attrs(self) -> str:
attrs: dict[str, str] = {}
Expand All @@ -32,11 +35,12 @@ class DotGraph:
https://en.wikipedia.org/wiki/DOT_(graph_description_language)
"""

def __init__(self, title: str, concentrate: bool = True) -> None:
def __init__(self, title: str, concentrate: bool = True, depth: int = 1) -> None:
self.title = title
self.nodes: set[str] = set()
self.edges: set[Edge] = set()
self.concentrate = concentrate
self.depth = depth

def add_node(self, name: str) -> None:
self.nodes.add(name)
Expand All @@ -54,12 +58,19 @@ def render(self) -> str:
}}""")

def _render_nodes(self) -> str:
return "\n".join(f'"{self.render_module(node)}"\n' for node in sorted(self.nodes))
return "\n".join(
f'"{self.render_module(node, self.title)}"\n' for node in sorted(self.nodes)
)

def _render_edges(self) -> str:
return "\n".join(str(edge) for edge in sorted(self.edges))
return "\n".join(edge.render(self.title) for edge in sorted(self.edges))

@staticmethod
def render_module(module: str) -> str:
# Render as relative module.
return f".{module.split('.')[-1]}"
def render_module(module: str, base_module: str = "") -> str:
# Render as relative module by stripping the base module prefix.
if base_module and module.startswith(base_module + "."):
relative = module[len(base_module) :]
return relative # Already starts with "."
else:
# Fallback: show as relative with just the last component
return f".{module.split('.')[-1]}"
60 changes: 60 additions & 0 deletions tests/unit/application/test_use_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,63 @@ def test_draw_graph_show_cycle_breakers(self):
),
Edge("mypackage.foo.red", "mypackage.foo.blue", emphasized=True),
}

def test_draw_graph_depth_2(self):
"""Test that depth=2 shows children AND grandchildren of the module."""

def build_graph_with_depth(package_name: str) -> grimp.ImportGraph:
graph = grimp.ImportGraph()
graph.add_module(package_name)
graph.add_module(SOME_MODULE)

# Create a hierarchy: foo.blue, foo.green, foo.blue.alpha, foo.blue.beta, foo.green.gamma
for child in ("blue", "green"):
graph.add_module(f"{SOME_MODULE}.{child}")
for grandchild in ("alpha", "beta"):
graph.add_module(f"{SOME_MODULE}.blue.{grandchild}")
graph.add_module(f"{SOME_MODULE}.green.gamma")

# Add imports at the grandchild level
graph.add_import(
importer=f"{SOME_MODULE}.blue.alpha",
imported=f"{SOME_MODULE}.green.gamma",
)
graph.add_import(
importer=f"{SOME_MODULE}.blue.beta",
imported=f"{SOME_MODULE}.green.gamma",
)
# Add import at the child level
graph.add_import(
importer=f"{SOME_MODULE}.blue",
imported=f"{SOME_MODULE}.green",
)
return graph

viewer = SpyGraphViewer()

use_cases.draw_graph(
SOME_MODULE,
show_import_totals=False,
show_cycle_breakers=False,
sys_path=[],
current_directory="/cwd",
get_top_level_package=fake_get_top_level_package_non_namespace,
build_graph=build_graph_with_depth,
viewer=viewer,
depth=2,
)

assert viewer.called_with_dot.depth == 2
# depth=2 includes both children (depth 1) AND grandchildren (depth 2)
assert viewer.called_with_dot.nodes == {
"mypackage.foo.blue",
"mypackage.foo.green",
"mypackage.foo.blue.alpha",
"mypackage.foo.blue.beta",
"mypackage.foo.green.gamma",
}
assert viewer.called_with_dot.edges == {
Edge("mypackage.foo.blue", "mypackage.foo.green"),
Edge("mypackage.foo.blue.alpha", "mypackage.foo.green.gamma"),
Edge("mypackage.foo.blue.beta", "mypackage.foo.green.gamma"),
}