forked from yarongmu-google/MLSys
Status: Open
Labels: enhancement (New feature or request)
Description
P5: Recomputation for diamond graphs
Problem
When a tensor is consumed by two downstream branches (the diamond/skip-connection pattern), spilling it to DRAM costs one write of `size / bandwidth` plus one read of `size / bandwidth` per consumer. Recomputing the tensor instead (including its producer in multiple subgraphs) may be cheaper when the producer's compute cost is low relative to that DRAM traffic.
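As a concrete sketch, the trade-off can be expressed with a toy cost model (all names and units here — `size_bytes`, `bandwidth`, `producer_cost` — are illustrative assumptions, not the project's actual cost model):

```python
# Toy spill-vs-recompute cost model for a diamond pattern.
# Units are arbitrary; only the comparison matters.

def spill_cost(size_bytes: float, bandwidth: float, num_consumers: int = 2) -> float:
    """Write tensor T to DRAM once, then read it back once per consumer."""
    write = size_bytes / bandwidth
    reads = num_consumers * (size_bytes / bandwidth)
    return write + reads

def recompute_cost(producer_cost: float, num_consumers: int = 2) -> float:
    """Run producer op A once per consuming subgraph; T stays ephemeral."""
    return num_consumers * producer_cost

def should_recompute(size_bytes: float, bandwidth: float,
                     producer_cost: float, num_consumers: int = 2) -> bool:
    """True when duplicating the producer beats a DRAM round trip."""
    return (recompute_cost(producer_cost, num_consumers)
            < spill_cost(size_bytes, bandwidth, num_consumers))
```

For example, a large tensor with a cheap producer favors recomputation: `should_recompute(4096, 1.0, 100.0)` returns `True` (recompute cost 200 vs. spill cost 12288).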
Reference
PROBLEM.md Example 3B: recomputation achieves 6,276.8 latency vs 11,468.8 for spilling (45% improvement).
Proposed Approach
- Identify diamond patterns: tensor T produced by op A, consumed by ops B and C in different subgraphs
- Compare: spill cost (write T + read T twice) vs recompute cost (run A twice, T stays ephemeral)
- If recompute is cheaper, include op A in both subgraphs
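The detection step could look roughly like the following. This is a minimal sketch assuming a flat edge list of `(producer_op, tensor, consumer_op)` triples and an op-to-subgraph map; neither matches the real IR in either track, and the names are hypothetical:

```python
from collections import defaultdict

def find_diamonds(edges, subgraph_of):
    """Find diamond candidates: tensors whose consumers span >= 2 subgraphs.

    edges: iterable of (producer_op, tensor, consumer_op) triples (assumed shape).
    subgraph_of: mapping from op name to its assigned subgraph id.
    Returns a list of (tensor, producer_op, sorted consumer_ops).
    """
    producer = {}
    consumers = defaultdict(list)
    for prod, tensor, cons in edges:
        producer[tensor] = prod
        consumers[tensor].append(cons)
    diamonds = []
    for tensor, cons in consumers.items():
        if len({subgraph_of[c] for c in cons}) >= 2:
            diamonds.append((tensor, producer[tensor], sorted(cons)))
    return diamonds
```

On a toy graph where op `A` produces `T`, consumed by `B` (subgraph 0) and `C` (subgraph 1), `T` is flagged as a diamond candidate; tensors consumed within a single subgraph are not.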
Acceptance Criteria
- Diamond pattern detection in DAG analysis
- Cost comparison: spill vs recompute
- Ops may appear in multiple subgraphs (already allowed in E2E validator)
- Track A (Rust) and Track B (Python) both updated
- Validated against Example 3B (6,276.8 latency target)
Dependencies
Depends on #16 (cost-based fusion) — requires multiple subgraphs to have boundaries where recomputation applies.