From 9865e37a026297dcbad9d1ccf52af2427d0c4320 Mon Sep 17 00:00:00 2001 From: Chen-Yu Yang Date: Wed, 4 Mar 2026 10:34:13 -0500 Subject: [PATCH] lazy disk 3 --- docs/disk_assign_design.md | 237 ++++++++++++++++++++++++++++++++++ tinygrad/engine/schedule.py | 6 +- tinygrad/schedule/indexing.py | 3 +- tinygrad/schedule/rangeify.py | 21 ++- tinygrad/tensor.py | 25 +++- tinygrad/uop/ops.py | 2 +- 6 files changed, 283 insertions(+), 11 deletions(-) create mode 100644 docs/disk_assign_design.md diff --git a/docs/disk_assign_design.md b/docs/disk_assign_design.md new file mode 100644 index 0000000000000..718281b912c5b --- /dev/null +++ b/docs/disk_assign_design.md @@ -0,0 +1,237 @@ +# Design: Proper DISK Assign (Removing the Hack) + +## Problem + +`tensor.py:assign` has a hack for DISK targets that bypasses the schedule entirely: + +```python +if is_disk: + self._buffer().copyin(x._data()) + return self +``` + +This directly realizes `x`, gets its raw bytes, and copies them to the DISK buffer. While it works, it: +- Bypasses the schedule (can't be batched/optimized) +- Doesn't compose with JIT capture +- Is a special case that doesn't go through the graph rewrite system + +**Goal**: Remove this hack and use proper ASSIGN + rewrite rules so the copy from compute device to disk goes through the normal schedule as a COPY ExecItem. + +## Current Infrastructure + +### How COPY-in-ASSIGN Already Works (for same-device) + +The schedule already supports ASSIGN with COPY/BUFFER_VIEW sources. The flow: + +1. **`realize_assign_src`** (`schedule/indexing.py:20`): When ASSIGN source is COPY/BUFFER_VIEW/ENCDEC, it's **unrealized** (removed from realize_map). This means the COPY stays as the ASSIGN source value instead of being independently bufferized. + +2. **`bufferize_to_store`** (`schedule/rangeify.py:325`): Creates `INDEX(BUFFER, ...).store(assign_src).end(ranges)`, wrapping the assign target buffer in AFTER. + +3. 
**`split_store`** (`schedule/rangeify.py:477-481`): Detects when the stored value is COPY or BUFFER_VIEW: + ```python + if stored.op in {Ops.COPY, Ops.BUFFER_VIEW}: + ret = stored.replace(src=stored.src + ret.ended_ranges) + ``` + This converts the kernel from a SINK (compute kernel) to a COPY/BUFFER_VIEW ExecItem. + +4. The COPY ExecItem's **output buffer** comes from the AFTER structure — it's the ASSIGN target buffer (the existing DISK buffer). + +### How DISK COPY Already Works (for `tensor.to("disk:...")`) + +- **`disk_copy_is_buffer`** (`engine/allocations.py:19`): For COPY-to-disk, creates a new DISK buffer in `buffer_map`. +- **`pm_finalize_call`** (`engine/allocations.py:162`): COPY-to-disk UOps are appended to the assigns list. +- **`BufferCopy`** (`engine/realize.py`): At execution time, uses `copyin` or optimized paths (io_uring, readinto) to transfer data. + +### Existing ASSIGN Rewrite Rules (`schedule/rangeify.py:125-138`) + +- **Collapse nested ASSIGN**: `ASSIGN(target, ASSIGN(target, src))` → `ASSIGN(target, src)` +- **Move bitcast to source**: `ASSIGN(BITCAST(target), src)` → `ASSIGN(target, src.bitcast(target.dtype))` +- **Normalize ASSIGN chains**: unwrap chained ASSIGN targets +- **Fix hazards**: make source contiguous if it contains hazardous movement ops on dest + +### Free Bitcast for DISK + +DISK supports free bitcast through `late_buffer_view` (`schedule/rangeify.py:266`): +- `BUFFERIZE(BITCAST(x))` on DISK → `BUFFER_VIEW(base, (size, offset))` +- BUFFER_VIEW is zero-copy: just a different dtype interpretation at an offset +- Used by `safe_save` for writing header length: `t[0:8].bitcast(dtypes.int64).assign([len(j)])` + +## Experiments + +### What Works Now (without the hack) + +**Full assign with buffer source** — PASS: +```python +dt = Tensor.empty(4, device="disk:...", dtype=dtypes.int32) +src = Tensor([10,20,30,40], dtype=dtypes.int32) +# Manually create ASSIGN(DISK_BUF, COPY(src, DISK)) +dt.uop = 
dt.uop.assign(src.uop.copy_to_device(dt.device)) +dt.realize() # Works! Data correctly written to disk. +``` + +The schedule produces 2 COPYs: +1. COPY from PYTHON → METAL (realize the list) +2. COPY from METAL → DISK (write to disk) + +The DISK buffer in the schedule IS the existing tensor's buffer (`sched[1].buf[0] is dt.uop.buffer` → True). + +**Full assign with CONST source (via contiguous)** — PASS: +```python +dt = Tensor.empty(4, device="disk:...", dtype=dtypes.int32) +src = Tensor.full((4,), 42, dtype=dtypes.int32) +# contiguous() prevents early_fixup_const_copy from optimizing COPY(CONST,DISK) → CONST(DISK) +dt.uop = dt.uop.assign(src.uop.contiguous().copy_to_device(dt.device)) +dt.realize() # Works! +``` + +### What Fails Now + +**Bare ASSIGN(DISK_BUF, src)** — FAIL: +```python +dt.assign(Tensor.full((4,), 42)).realize() +# NotImplementedError: needs a renderer +``` +DISK has no renderer, so kernels can't execute on it. + +**ASSIGN with CONST source without contiguous** — FAIL: +```python +# COPY(CONST(42, METAL), DISK) is optimized to CONST(42, DISK) by early_fixup_const_copy +# Then ASSIGN(DISK_BUF, CONST(42, DISK)) creates kernel on DISK → fails +``` + +**Slice assign with COPY** — WRONG OFFSET: +```python +dt[2:5].assign(...) # Writes to offset 0 instead of offset 2 +``` +The COPY creates a new buffer at offset 0. The slice offset info from the ASSIGN target (SHRINK) is lost when the kernel becomes a COPY. + +## Proposed Design + +### Where to Make the Change + +**In `earliest_rewrites`** (`schedule/rangeify.py`), add a rule that converts DISK ASSIGN sources to COPY. 
This runs inside `get_kernel_graph` (after `transform_to_call`), so: +- It's after `add_tags` (no interference with `disk_copy_is_buffer`) +- It's after `pm_early_transform_tensor_graph` (bitcast rules have already fired) +- PARAMs already have `_device` set, so we can check if target is DISK +- The COPY won't be processed by `pm_finalize_call`'s standalone COPY-to-disk rule (that only runs in `transform_to_call`) + +### The Rule + +```python +# In earliest_rewrites (schedule/rangeify.py) +def disk_assign_wrap_copy(assign:UOp): + """For DISK assigns, wrap the source in a COPY so it becomes a COPY ExecItem instead of a kernel.""" + target = assign.src[0] + # Walk through ASSIGN/BITCAST/AFTER to find the base buffer + base = target + while base.op in {Ops.ASSIGN, Ops.BITCAST, Ops.AFTER}: base = base.src[0].base + if base.op not in {Ops.BUFFER, Ops.PARAM}: return None + device = base._device + if not (isinstance(device, str) and device.startswith("DISK")): return None + src = assign.src[1] + # If source is already a COPY to this device, no change needed + if src.op is Ops.COPY and src._device == device: return None + # Wrap source in COPY to disk + return assign.replace(src=(target, src.copy_to_device(device))) + +(UPat(Ops.ASSIGN, name="assign"), disk_assign_wrap_copy), +``` + +### Why This Works + +After the rule fires, the graph is: +``` +ASSIGN(INDEX(DISK_BUF, offset...), COPY(src_indexed, DISK_DEVICE)) +``` + +1. **`realize_assign_src`** unrealizes the COPY (doesn't get its own buffer) +2. **`bufferize_to_store`** creates: `DISK_BUF.after(INDEX(...).store(COPY(src, DISK)).end(ranges))` +3. **`split_store`** sees `stored.op is Ops.COPY` → converts to COPY kernel +4. The COPY kernel's **output buffer = existing DISK_BUF** (from the AFTER structure) +5. The COPY kernel's **input buffer = src buffer** (on compute device) +6. 
The **offset is preserved** in the INDEX/ranges, which get passed to the COPY + +### Slice Assign Flow + +``` +dt[2:5].assign(Tensor([99,99,99])) +``` + +1. Tensor graph: `ASSIGN(SHRINK(DISK_BUF, 2, 5), CONST(99))` +2. After rangeify: `ASSIGN(INDEX(DISK_BUF, range(2,5)), src_indexed)` +3. Our rule: `ASSIGN(INDEX(DISK_BUF, range(2,5)), COPY(src_indexed, DISK))` +4. bufferize_to_store: `DISK_BUF.after(INDEX(DISK_BUF, range(2,5)).store(COPY(src, DISK)).end(range))` +5. split_store: COPY kernel with output = DISK_BUF +6. BufferCopy writes data to the DISK buffer — the **offset is handled by the INDEX** + +**Key**: The INDEX(DISK_BUF, offset_range) preserves the slice offset. The COPY writes to the correct region of the DISK buffer because the AFTER targets the full DISK_BUF, and the INDEX+ranges encode where within the buffer to write. Concretely, `split_store` reads the offset and size from that INDEX and replaces the COPY's output buffer with a `BUFFER_VIEW(DISK_BUF, (size, offset))`, so BufferCopy writes at the correct offset. + +### Bitcast Assign Flow + +``` +t[0:8].bitcast(dtypes.int64).assign([12345]) +``` + +1. Existing bitcast rule fires first: `ASSIGN(BITCAST(target), src)` → `ASSIGN(target, src.bitcast(target.dtype))` + - Result: `ASSIGN(SHRINK(DISK_BUF, 0, 8), BITCAST(CONST(12345, int64), uint8))` +2. After rangeify, the BITCAST on the source becomes part of the indexed expression +3. Our rule wraps in COPY: `ASSIGN(INDEX(DISK_BUF, ...), COPY(src_with_bitcast, DISK))` +4. split_store → COPY kernel +5. BufferCopy copies the raw bytes to the correct offset in the DISK buffer + +**No compute kernel needed for the bitcast** — the bitcast is just a reinterpretation of bytes. The COPY transfers raw bytes regardless of dtype. + +### CONST Source Handling + +For CONST sources (e.g., `Tensor.full`): +- The COPY source is the CONST expression +- The CONST gets bufferized normally (materialized on the compute device) +- Then the COPY transfers from compute device to DISK + +**Important**: The `early_fixup_const_copy` rule (`pm_early_transform_tensor_graph:137`) runs in `transform_to_call` BEFORE our rule. It converts `COPY(CONST, device)` → `CONST(device)`. 
Since our COPY is created inside `get_kernel_graph` (AFTER `transform_to_call`), that pass doesn't interfere. Note, however, that `earliest_rewrites` itself begins with an "early fixup const copy" pattern on COPY, so the implemented rule wraps a CONST source in `.contiguous()` before the COPY (the same workaround used in the experiment above). + +### Changes to `tensor.py:assign` + +Remove the commented-out hack. No other changes needed — the ASSIGN is created normally, and our rewrite rule handles DISK. + +One consideration: the current code relaxes device/dtype checks for DISK: +```python +if not is_disk and self.device != x.device: raise RuntimeError(...) +if not is_disk and self.dtype != x.dtype: raise RuntimeError(...) +``` +These relaxations should stay — DISK assign allows cross-device sources and (via bitcast) different dtypes. + +### Changes to `engine/allocations.py` + +`disk_copy_is_buffer` may need adjustment: currently it creates a buffer_map entry for ALL COPY-to-disk UOps. For COPYs created inside `get_kernel_graph`, there's no interference (they don't exist in the tensor graph). But if a COPY-to-disk appears in the tensor graph (created at the tensor level), the buffer_map entry could cause issues in `linear_to_schedule` (accessing `.buffer` on a COPY UOp fails). + +**Fix**: Either don't create buffer_map entries for COPYs that are inside ASSIGNs, or skip `.buffer` access for non-buffer UOps in `linear_to_schedule`. 
+ +## Key Files to Modify + +| File | Change | +|------|--------| +| `tinygrad/tensor.py` | Remove the DISK assign hack (already commented out) | +| `tinygrad/schedule/rangeify.py` | Add `disk_assign_wrap_copy` rule to `earliest_rewrites` | + +## Testing + +Run the existing disk tests: +```bash +python -m pytest test/unit/test_disk_tensor.py -xvs +``` + +Key tests to verify: +- `test_assign_const_to_disk` — CONST source +- `test_assign_slice_from_const` — sliced CONST source +- `test_assign_disk_to_disk` — disk-to-disk via CPU +- `test_assign_slice` — slice assign +- `test_assign_to_different_dtype` — cross-dtype assign +- `test_assign_with_bitcast` — bitcast + assign (used by safe_save) +- `test_assign_to_bitcast_view` — assign to bitcast view +- `test_assign_cross_device` — cross-device assign + +Also test safe_save (the primary consumer): +```bash +python -m pytest test/unit/test_disk_tensor.py::TestSafetensors -xvs +``` diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index ed9ef96f9b7ba..ebf044ba01887 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -11,7 +11,7 @@ # unwrap VIEW/CAST/etc to find the actual data source (kernel output, buffer, or multi-device op) def _unwrap_src(s: UOp) -> UOp: - while len(s.src) and s.op not in {Ops.AFTER, Ops.BUFFER, Ops.PARAM, Ops.MSELECT, Ops.MSTACK, Ops.BIND}: s = s.src[0] + while len(s.src) and s.op not in {Ops.AFTER, Ops.BUFFER, Ops.BUFFER_VIEW, Ops.PARAM, Ops.MSELECT, Ops.MSTACK, Ops.BIND}: s = s.src[0] return s def create_schedule(sched_sink:UOp) -> UOp: @@ -39,8 +39,8 @@ def create_schedule(sched_sink:UOp) -> UOp: assert ss.op is Ops.AFTER, f"ss.op is not AFTER, it's {ss.op}" children.setdefault(ss.src[1], []).append(k) in_degree[k] += 1 - case Ops.BUFFER | Ops.PARAM | Ops.BIND: - pass # BUFFER/PARAM is already realized, BIND is a bound variable (not a buffer dependency) + case Ops.BUFFER | Ops.BUFFER_VIEW | Ops.PARAM | Ops.BIND: + pass # BUFFER/BUFFER_VIEW/PARAM is 
already realized, BIND is a bound variable (not a buffer dependency) case _: raise RuntimeError(f"input to kernel must be AFTER, BUFFER, PARAM, MSELECT, MSTACK, or BIND, not {s.op}") diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py index 56a103ea783fb..b953487a9b673 100644 --- a/tinygrad/schedule/indexing.py +++ b/tinygrad/schedule/indexing.py @@ -19,8 +19,9 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None: def realize_assign_src(ctx:dict[UOp, None], buf:UOp, x:UOp): # don't realize COPY/BUFFER_VIEW/ENCDEC when they are the direct source of ASSIGN — the ASSIGN target buffer is the output + is_disk = isinstance(buf.base._device, str) and buf.base._device.startswith(("DISK", "TINYFS")) if x.op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC} and x in ctx \ - and not buf.op_in_backward_slice_with_self(Ops.SHRINK, Ops.PERMUTE, Ops.FLIP, Ops.PAD): + and (is_disk or not buf.op_in_backward_slice_with_self(Ops.SHRINK, Ops.PERMUTE, Ops.FLIP, Ops.PAD)): del ctx[x] # you don't usually have to do this for assign unless there's a WAR hazard like TestAssign.test_assign_double_diamond_reduce if buf.base in x.backward_slice_with_self: ctx[x] = None diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 9174f851d0a4b..a33259ffbaeea 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -112,6 +112,13 @@ def resolve_call(c:UOp, allow_param_mismatch=True) -> UOp|None: if p.dtype != a.dtype: raise TypeError(f"arg {i} dtype mismatch: expected {p.dtype}, got {a.dtype}") return c.src[0].substitute(dict_map, walk=True) +def disk_assign_wrap_copy(assign:UOp, target:UOp, src:UOp): + """For DISK assigns, wrap the source in a COPY so it becomes a COPY ExecItem instead of a kernel.""" + if target.base.op not in {Ops.BUFFER, Ops.PARAM, Ops.AFTER}: return None + if not (isinstance(target.base._device, str) and target.base._device.startswith("DISK")): return None + if src.op is Ops.COPY and src._device == 
target.base._device: return None + return assign.replace(src=(target, (src.contiguous() if src.base.op is Ops.CONST else src).copy_to_device(target.base._device))) + earliest_rewrites = mop_cleanup+PatternMatcher([ # early fixup const copy (UPat(Ops.COPY, src=(UPat.var("s"), UPat.var("d"))), @@ -162,6 +169,9 @@ def resolve_call(c:UOp, allow_param_mismatch=True) -> UOp|None: # make source contiguous if it has hazardous movement ops on the dest buffer (UPat(Ops.ASSIGN, src=(UPat.var("target"), UPat.var("src")), name="assign"), fix_assign_hazard), + # for DISK assigns, wrap source in COPY so it becomes a COPY ExecItem instead of a kernel (DISK has no renderer) + (UPat(Ops.ASSIGN, src=(UPat.var("target"), UPat.var("src")), name="assign"), disk_assign_wrap_copy), + # ** size 0 ** # reduce of size 0 is the identity element @@ -520,7 +530,16 @@ def split_store(x:UOp) -> UOp|None: if ret.op is Ops.STORE: stored = ret.src[1] elif ret.op is Ops.END and ret.src[0].op is Ops.STORE: stored = ret.src[0].src[1] else: raise RuntimeError(f"unknown kernel type {ret.op}") - if stored.op in {Ops.COPY, Ops.BUFFER_VIEW}: ret = stored.replace(src=stored.src + ret.ended_ranges) + if stored.op in {Ops.COPY, Ops.BUFFER_VIEW}: + # for COPY to DISK with offset, replace output buffer with BUFFER_VIEW so BufferCopy writes at correct offset + if stored.op is Ops.COPY: + idx = (ret.src[0] if ret.op is Ops.END else ret).src[0] + out_key, out_buf = next(iter(lctx.map.items())) + if idx.op is Ops.INDEX and len(idx.src) > 1 and isinstance(out_buf._device, str) and out_buf._device.startswith(("DISK", "TINYFS")): + offset, out_size = idx.src[1].vmin, idx.src[1].vmax - idx.src[1].vmin + 1 + if offset > 0 or out_buf.size != out_size: + lctx.map[out_key] = UOp(Ops.BUFFER_VIEW, out_buf.dtype, (out_buf,), (out_size, offset)) + ret = stored.replace(src=stored.src + ret.ended_ranges) elif stored.op is Ops.ENCDEC: ret = stored else: ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts)) diff --git 
a/tinygrad/tensor.py b/tinygrad/tensor.py index fe9f1a8e93c3f..964e161795817 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -301,17 +301,30 @@ def assign(self, x:Tensor|PyConst|list|tuple) -> Tensor: if not is_disk and self.dtype != x.dtype: raise RuntimeError(f"assign dtype mismatch {self.dtype} != {x.dtype}") if isinstance(self.device, tuple) and self.uop.axis != x.uop.axis: raise RuntimeError(f"multi axis mismatch {self.uop.axis} != {x.uop.axis}") - # TODO: this is a hack for writing to DISK. remove with working assign - if is_disk: - self._buffer().copyin(x._data()) + # cross-size BITCAST on disk can't be compiled as a kernel — handle eagerly + if is_disk and self.uop.base.op is Ops.BITCAST: + self._buffer().copyin(x.to("CPU")._data()) return self + # strided disk views (e.g. dt[::2]) can't be represented as a contiguous BUFFER_VIEW + if is_disk and not self.uop.has_buffer_identity() and self.uop.contiguous_view_offset() is None: + raise RuntimeError(f"cannot collapse movement ops on {self.uop.base.op} to a contiguous view") + # for disk view assigns, the disk file must exist before writing to a slice — realize the base tensor first + base = self.uop.base + if is_disk and not self.uop.has_buffer_identity() and base.op not in {Ops.BUFFER, Ops.AFTER}: + for tref in list(all_tensors): + if (t_root:=tref()) is not None and t_root.uop is base: + t_root.realize() + break + base = self.uop.base # NOTE: assign_uop is created before AFTER embedding (uses original self.uop), # but AFTER must be embedded before _apply_uop (so subsequent assigns see it) assign_uop = self.uop.assign(x.uop) - base = self.uop.base if base.op in {Ops.BUFFER, Ops.AFTER} and not self.uop.has_buffer_identity(): _apply_map_to_tensors({base: base.after(assign_uop)}, name="Embed View Assign", walk=True) - return self.replace(self._apply_uop(lambda *_: assign_uop, x)) + ret = self.replace(self._apply_uop(lambda *_: assign_uop, x)) + # disk assigns must be realized immediately — callers 
expect data to be written + if is_disk: ret.realize() + return ret def detach(self) -> Tensor: """ @@ -1338,6 +1351,8 @@ def __setitem__(self, indices, v:Tensor|PyConst|list|tuple) -> None: if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype) self.assign(self._getitem(indices, v)) elif is_disk or self.uop.is_realized or self.uop.base.op is Ops.AFTER: # basic setitem, self is realized + # for disk tensors, realize the initial buffer first so the file exists before writing to a slice + if is_disk and not (self.uop.is_realized or self.uop.base.op is Ops.AFTER): self.realize() view = self[indices] if isinstance(v, Tensor) and v.uop.op is Ops.ASSIGN and v.uop in view.uop.base.src: return view.assign(v) diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 829a3631f339d..13e21477d143e 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -649,7 +649,7 @@ def _device(self) -> str|tuple[str, ...]|None: return None @property def buf_uop(self) -> UOp: - if self.op in {Ops.BUFFER, Ops.PARAM}: return self + if self.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.PARAM}: return self if self.op is Ops.MSELECT: return self.src[0].buf_uop.mselect(self.arg) if self.op is Ops.MSTACK: return UOp(Ops.MSTACK, self.dtype, src=tuple(x.buf_uop for x in self.src)) if self.base.op is Ops.AFTER: return self.base.src[0].buf_uop.base