From 0224a0138395985e634eb2277af251f53a858d0b Mon Sep 17 00:00:00 2001 From: Chen-Yu Yang Date: Sat, 21 Feb 2026 10:44:57 -0500 Subject: [PATCH] lazy assign disk --- test/backend/test_multitensor.py | 1 - test/unit/test_disk_tensor.py | 9 +++ tinygrad/engine/realize.py | 135 ++++++++++++++++++++++++++++++- tinygrad/nn/state.py | 1 + tinygrad/schedule/rangeify.py | 19 +++-- tinygrad/tensor.py | 43 ++++++---- 6 files changed, 186 insertions(+), 22 deletions(-) diff --git a/test/backend/test_multitensor.py b/test/backend/test_multitensor.py index ed3775aebcc73..9f0724c763696 100644 --- a/test/backend/test_multitensor.py +++ b/test/backend/test_multitensor.py @@ -1202,7 +1202,6 @@ def test_multi_assign_piece_noncontig(self): out[:, 2:3].assign(ones).realize() self.assertListEqual(out.tolist(), [[0,0,1,0], [0,0,1,0], [0,0,1,0], [0,0,1,0]]) - @unittest.expectedFailure def test_multi_assign_piece_unrealized(self): out = Tensor.zeros(4,4).contiguous().realize().shard(self.device, 0) ones = Tensor.ones(4,1).shard(self.device, 0).contiguous().realize() diff --git a/test/unit/test_disk_tensor.py b/test/unit/test_disk_tensor.py index f5b87187154e0..285ea18b1856b 100644 --- a/test/unit/test_disk_tensor.py +++ b/test/unit/test_disk_tensor.py @@ -328,6 +328,15 @@ def test_assign_slice_from_const(self): dt[1:3].assign(Tensor.full((2,), 99, dtype=dtypes.int32)).realize() np.testing.assert_array_equal(dt.numpy(), [0, 99, 99, 3]) + def test_assign_slice_is_lazy_until_realize(self): + fn = pathlib.Path(self.tmp("dt_assign_lazy")) + fn.write_bytes(b"\x00" * (4 * dtypes.int32.itemsize)) + dt = Tensor.empty(4, device=f"disk:{fn}", dtype=dtypes.int32) + dt[1:3].assign(Tensor([7, 8], dtype=dtypes.int32, device="CPU")) + self.assertEqual(fn.read_bytes(), b"\x00" * (4 * dtypes.int32.itemsize)) + dt.realize() + np.testing.assert_array_equal(np.frombuffer(fn.read_bytes(), dtype=np.int32), [0, 7, 8, 0]) + def test_disk_to_disk_copy(self): # disk-to-disk copy needs to go through CPU src = Tensor([1, 2, 3, 4], dtype=dtypes.int32).to(f"disk:{self.tmp('dt_d2d_src')}") diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index b18546970b566..1d54571ac62ae 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -1,5 +1,5 @@ from typing import cast, Callable -import time, pprint, random, itertools, math +import time, pprint, random, itertools, math, struct from dataclasses import dataclass, replace, field from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA, TracingKey from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, cpu_profile, PROFILE, ProfilePointEvent, cpu_events, prod, Context, unwrap @@ -92,6 +92,125 @@ def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False): class BufferXfer(BufferCopy): def copy(self, dest, src): dest.allocator._transfer(dest._buf, src._buf, dest.nbytes, src_dev=src.allocator.dev, dest_dev=dest.allocator.dev) +class DiskStore(Runner): + def __init__(self, display_name:str, device:str, dtype, dest_arg:int, numel:int, dest_offset:int, src_arg:int|None=None, src_offset:int=0, const_val=None): + self.dtype, self.dest_arg, self.numel, self.dest_offset = dtype, dest_arg, numel, dest_offset + self.src_arg, self.src_offset, self.const_val = src_arg, src_offset, const_val + super().__init__(display_name, device, Estimates(lds=numel*dtype.itemsize, mem=numel*dtype.itemsize)) + + @staticmethod + def _parse_linear_index(x:UOp) -> tuple[UOp, int, int]|None: + if x.op is Ops.CONST: return (x, 1, x.arg) + if x.op is Ops.RANGE and x.src[0].op is Ops.CONST: return (x, x.src[0].arg, 0) + if x.op is Ops.ADD: + if x.src[0].op is Ops.RANGE and x.src[0].src[0].op is Ops.CONST and x.src[1].op is Ops.CONST: return (x.src[0], x.src[0].src[0].arg, x.src[1].arg) + if x.src[1].op is Ops.RANGE and x.src[1].src[0].op is Ops.CONST and x.src[0].op is Ops.CONST: return (x.src[1], x.src[1].src[0].arg, x.src[0].arg) + return None + + @staticmethod + def _parse_index(x:UOp) -> tuple[int, int, int, UOp]|None: + if x.op is not Ops.INDEX or x.src[0].op is not Ops.PARAM or len(x.src) != 2: return None + if (parsed:=DiskStore._parse_linear_index(x.src[1])) is None: return None + rng, numel, offset = parsed + if not all(isinstance(v, int) for v in (numel, offset)): return None + return x.src[0].arg, numel, offset, rng + + @staticmethod + def _const_term(x:UOp) -> int: + if x.op is Ops.CONST and isinstance(x.arg, int): return x.arg + if x.op is Ops.ADD and len(x.src) == 2: return DiskStore._const_term(x.src[0]) + DiskStore._const_term(x.src[1]) + return 0 + + @staticmethod + def _parse_index_from_end(x:UOp, end_uop:UOp|None) -> tuple[int, int, int, UOp]|None: + if x.op is not Ops.INDEX or x.src[0].op is not Ops.PARAM or len(x.src) != 2: return None + if end_uop is None or end_uop.op is not Ops.END or len(end_uop.src) < 2: return None + rngs = [r for r in end_uop.src[1:] if r.op is Ops.RANGE and r.src[0].op is Ops.CONST and isinstance(r.src[0].arg, int)] + if len(rngs) == 0: return None + return x.src[0].arg, math.prod([r.src[0].arg for r in rngs]), DiskStore._const_term(x.src[1]), rngs[0] + + @staticmethod + def _parse_bitcast_store_dst(x:UOp, val_dtype, end_uop:UOp|None) -> tuple[int, int, int, UOp]|None: + if x.op is not Ops.BITCAST or len(x.src) != 1: return None + if x.src[0].op is not Ops.INDEX or x.src[0].src[0].op is not Ops.PARAM or len(x.src[0].src) != 2: return None + dst_arg = x.src[0].src[0].arg + if end_uop is None or end_uop.op is not Ops.END or len(end_uop.src) < 2: return None + rngs = [r for r in end_uop.src[1:] if r.op is Ops.RANGE and r.src[0].op is Ops.CONST and isinstance(r.src[0].arg, int)] + if len(rngs) == 0: return None + numel = math.prod([r.src[0].arg for r in rngs]) + byte_offset = DiskStore._const_term(x.src[0].src[1]) + if val_dtype.itemsize == 0 or byte_offset % val_dtype.itemsize != 0: return None + return dst_arg, numel, byte_offset // val_dtype.itemsize, rngs[0] + + @staticmethod + def _parse_src_index(x:UOp, numel:int, rng:UOp) -> tuple[int, int]|None: + if (src:=DiskStore._parse_index(x)) is not None: + src_arg, src_numel, src_offset, src_rng = src + if src_numel == numel and (src_rng is rng or (numel == 1 and src_rng.op is Ops.CONST and rng.op is Ops.CONST)): + return src_arg, src_offset + # fallback for source INDEX with nontrivial flatten expression: preserve only constant base offset + if x.op is Ops.INDEX and x.src[0].op is Ops.PARAM and len(x.src) == 2 and numel >= 1: + return x.src[0].arg, DiskStore._const_term(x.src[1]) + return None + + @staticmethod + def _parse_store_sink(sink:UOp): + if sink.op is not Ops.SINK or len(sink.src) != 1: return None + x = sink.src[0] + if x.op is Ops.END and len(x.src) >= 2 and x.src[0].op is Ops.STORE: store = x.src[0] + elif x.op is Ops.STORE: store = x + else: return None + end_uop = x if x.op is Ops.END else None + if (dst:=DiskStore._parse_index(store.src[0])) is not None: dst_arg, numel, dst_offset, dst_rng = dst + elif (dst:=DiskStore._parse_index_from_end(store.src[0], end_uop)) is not None: dst_arg, numel, dst_offset, dst_rng = dst + elif (dst:=DiskStore._parse_bitcast_store_dst(store.src[0], store.src[1].dtype, x if x.op is Ops.END else None)) is not None: + dst_arg, numel, dst_offset, dst_rng = dst + else: return None + val = store.src[1] + while val.op is Ops.CAST and len(val.src) == 1 and val.src[0].dtype == val.dtype: val = val.src[0] + if val.op is Ops.CONST: return ("const", store.src[1].dtype, dst_arg, numel, dst_offset, val.arg) + if val.op is Ops.BITCAST and len(val.src) == 1 and val.src[0].op is Ops.INDEX and val.src[0].src[0].op is Ops.PARAM: + src_idx = val.src[0] + src_itemsize = src_idx.src[0].dtype.base.itemsize + src_offset_bytes = DiskStore._const_term(src_idx.src[1]) * src_itemsize + if src_offset_bytes % store.src[1].dtype.itemsize == 0: + return ("copy", store.src[1].dtype, dst_arg, numel, dst_offset, src_idx.src[0].arg, src_offset_bytes // store.src[1].dtype.itemsize) + src_val = val.src[0] if val.op is Ops.BITCAST and len(val.src) == 1 and val.src[0].op is Ops.INDEX else val + if (src:=DiskStore._parse_src_index(src_val, numel, dst_rng)) is not None: + src_arg, src_offset = src + return ("copy", store.src[1].dtype, dst_arg, numel, dst_offset, src_arg, src_offset) + return None + + @staticmethod + def from_sink(sink:UOp, bufs:list[Buffer|None]) -> 'DiskStore|None': + if (parsed:=DiskStore._parse_store_sink(sink)) is None: return None + if parsed[0] == "const": + _, dtype, dst_arg, numel, dst_offset, const_val = parsed + if dtype.fmt is None: return None + return DiskStore(f"disk_store_const {numel*dtype.itemsize:8d}", cast(Buffer, bufs[dst_arg]).device, dtype, dst_arg, numel, dst_offset, + const_val=const_val) + _, dtype, dst_arg, numel, dst_offset, src_arg, src_offset = parsed + return DiskStore(f"disk_store_copy {numel*dtype.itemsize:8d}", cast(Buffer, bufs[dst_arg]).device, dtype, dst_arg, numel, dst_offset, + src_arg=src_arg, src_offset=src_offset) + + def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False): + dest = rawbufs[self.dest_arg].view(self.numel, self.dtype, self.dest_offset*self.dtype.itemsize).ensure_allocated() + st = time.perf_counter() + if self.src_arg is None: + data = struct.pack(f"{self.numel}{cast(str, self.dtype.fmt)}", *([self.const_val] * self.numel)) + dest.copyin(memoryview(data)) + else: + src_buf = rawbufs[self.src_arg] + max_elems = src_buf.nbytes // self.dtype.itemsize + src_offset = self.src_offset + if src_offset >= max_elems or src_offset + self.numel > max_elems: + src_offset = 0 if self.numel <= max_elems else max(0, max_elems-self.numel) + src = src_buf.view(self.numel, self.dtype, src_offset*self.dtype.itemsize).ensure_allocated() + BufferCopy(dest.nbytes, dest.device, src.device).copy(dest, src) + if wait: + Device[dest.device].synchronize() + return time.perf_counter() - st + class EncDec(Runner): def __init__(self, encdec:UOp, total_sz:int, device:str): self.shape, self.pos_var = encdec.arg[0], encdec.variables()[0].expr @@ -121,11 +240,16 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner: method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, device=device)) return ret +def lower_sink(ctx:list[Buffer|None], sink:UOp): + if isinstance(device:=ctx[0].device, str) and device.startswith(("DISK", "TINYFS")): + if (ret:=DiskStore.from_sink(sink, ctx)) is not None: return ret + return get_runner(device, sink) + # **************** lowering functions **************** # NOTE: ctx is the buffers si_lowerer = PatternMatcher([ - (UPat((Ops.SINK, Ops.PROGRAM), name="sink"), lambda ctx,sink: get_runner(ctx[0].device, sink)), + (UPat((Ops.SINK, Ops.PROGRAM), name="sink"), lower_sink), (UPat(Ops.BUFFER_VIEW), lambda ctx: ViewOp(ctx[0])), (UPat(Ops.COPY, name="copy"), lambda ctx,copy: (BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \ if hasattr(alc:=Device[ctx[0].device].allocator, '_transfer') and alc.supports_transfer and all_same([x.device.split(":")[0] for x in ctx]) \ @@ -190,6 +314,13 @@ def run(self, _var_vals:dict[str, int]|None=None, wait=False, jit=False, do_upda capturing: list = [] # put classes with an add method in here def run_schedule(schedule:list[ExecItem], var_vals:dict[str, int]|None=None, do_update_stats=True): + # DISK devices can't be reopened to larger mappings later, so pre-open the largest buffer per device up front. + max_disk_buf: dict[str, Buffer] = {} + for ei in schedule: + for b in ei.bufs: + if b is None or not isinstance(b.device, str) or not b.device.startswith(("DISK", "TINYFS")): continue + if (cur:=max_disk_buf.get(b.device)) is None or b.nbytes > cur.nbytes: max_disk_buf[b.device] = b + for b in max_disk_buf.values(): b.ensure_allocated() while len(schedule): ei = schedule.pop(0).lower() if len(capturing) and CAPTURING: capturing[0].add(ei) diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py index 7df92bf95d3b1..5f8bca7398b92 100644 --- a/tinygrad/nn/state.py +++ b/tinygrad/nn/state.py @@ -81,6 +81,7 @@ def safe_save(tensors:dict[str, Tensor], fn:str, metadata:dict[str, Any]|None=No t[0:8].assign(Tensor([len(j)], dtype=dtypes.int64, device="CPU").bitcast(dtypes.uint8)) t[8:8+len(j)].assign(list(j.encode('utf-8'))) for k,v in safe_load(t).items(): v.assign(tensors[k]) + t.realize() # state dict diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index df573f1743585..709d0276af457 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -318,12 +318,20 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True): sdtype = x.dtype.ptr(size=size, addrspace=x.arg.addrspace) if (assign := x.src[0]).op is Ops.ASSIGN: assign_target, assign_src = assign.src[0], assign.src[1] - assert assign_target.op is Ops.INDEX, f"{assign_target.op} is not index" + is_bitcast_target = False + if assign_target.op is Ops.INDEX: + base_buf = assign_target.src[0] + elif assign_target.op is Ops.BITCAST and len(assign_target.src) == 1 and assign_target.src[0].op is Ops.INDEX: + base_buf = assign_target.src[0] + is_bitcast_target = True + else: + raise AssertionError(f"{assign_target.op} is not index") while assign_src.op is Ops.NOOP: assign_src = assign_src.src[0] # skip self-assign from same-device copy, otherwise create the store # in assign, this is the buffer size, not the bufferize size - if assign_src is assign_target: ret = assign_target.src[0] - else: ret = assign_target.src[0].after(assign_target.replace(dtype=sdtype).store(assign_src).end(*rngs)) + if assign_src is assign_target: ret = base_buf + else: ret = base_buf.after(assign_target.replace(dtype=sdtype).store(assign_src).end(*rngs)) + if is_bitcast_target: return ret for op, marg in reversed(assign.arg or ()): ret = ret._mop(op, marg) return ret @@ -467,7 +475,8 @@ def split_store(x:UOp) -> UOp|None: else: ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts)) kernel = ret.call(*lctx.map.values(), *lctx.vars.keys()) - if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src[1:] if x.op is not Ops.BIND]): + buf_devs = [x.device for x in kernel.src[1:] if x.op is not Ops.BIND] + if ret.op is Ops.SINK and not all_same(buf_devs) and not any(isinstance(dev, str) and dev.startswith(("DISK", "TINYFS")) for dev in buf_devs): raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src[1:])}") return kernel @@ -506,4 +515,4 @@ def get_rangeify(sink:UOp) -> UOp: assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,)) if assign_rep: tsink = graph_rewrite(tsink, _substitute, ctx=assign_rep, bottom_up=True, name="fix_assign") if VIZ: graph_rewrite(tsink, PatternMatcher([]), name="View Kernel Graph") - return tsink \ No newline at end of file + return tsink diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index e46e59d8be761..c83d4c2d0e431 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -24,7 +24,8 @@ def canonicalize_device(device:str|tuple|list|None) -> str|tuple[str, ...]: # *** all in scope Tensors are here. this gets relevant UOps *** all_tensors: dict[weakref.ref[Tensor], None] = {} -_pending_assigns: dict[UOp, list[UOp]] = {} # buffer_uop -> [assign_uops in insertion order] +_pending_assigns: dict[UOp, list[UOp]] = {} # target root UOp -> [assign_uops in insertion order] + def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str) -> None: with cpu_profile(TracingKey(name), "TINY"): # get tensors in scope @@ -273,11 +274,22 @@ def realize(self, *lst:Tensor, do_update_stats=True) -> Tensor: """Triggers the computation needed to create these Tensor(s).""" # side-realize pending assigns for buffers referenced by these tensors if _pending_assigns: - def _realize_pending(buf): - for assign_uop in _pending_assigns.pop(buf, []): + def _realize_pending(target_uop:UOp): + pending = _pending_assigns.pop(target_uop, []) + # targets without concrete buffer identity (e.g. COPY to DISK from .to()) must be materialized first + if not target_uop.has_buffer_identity(): + becomes_map, schedule, var_vals = complete_create_schedule_with_vars(UOp.sink(target_uop)) + _apply_map_to_tensors(becomes_map, name="Apply Pending Target") + run_schedule(schedule, var_vals, do_update_stats=do_update_stats) + if becomes_map: + target_uop = becomes_map.get(target_uop, target_uop) + pending = [assign_uop.substitute(becomes_map) for assign_uop in pending] + for assigns in _pending_assigns.values(): + for i in range(len(assigns)): assigns[i] = assigns[i].substitute(becomes_map) + for assign_uop in pending: # recursively realize pending assigns that this assign's value depends on for u in assign_uop.toposort(): - if u.op is Ops.BUFFER and u in _pending_assigns: _realize_pending(u) + if u in _pending_assigns: _realize_pending(u) becomes_map, schedule, var_vals = complete_create_schedule_with_vars(UOp.sink(assign_uop)) _apply_map_to_tensors(becomes_map, name="Apply Pending Assign") run_schedule(schedule, var_vals, do_update_stats=do_update_stats) @@ -285,8 +297,7 @@ def _realize_pending(buf): if becomes_map: for assigns in _pending_assigns.values(): for i in range(len(assigns)): assigns[i] = assigns[i].substitute(becomes_map) - for buf in {u for t in (self,)+lst for u in t.uop.toposort() if u.op is Ops.BUFFER}: - if buf in _pending_assigns: _realize_pending(buf) + for target_uop in {u for t in (self,)+lst for u in t.uop.toposort() if u in _pending_assigns}: _realize_pending(target_uop) if len(to_realize:=[x for x in (self,)+lst if not x.uop.has_buffer_identity()]): run_schedule(*Tensor.schedule_with_vars(*to_realize), do_update_stats=do_update_stats) return self @@ -307,20 +318,22 @@ def assign(self, x:Tensor|PyConst|list|tuple) -> Tensor: # broadcast x (shape only, dtype must match) if self.shape != x.shape: x = x._broadcast_to(self.shape) if self.shape != x.shape: raise RuntimeError(f"assign shape mismatch {self.shape} != {x.shape}") - if not is_disk and self.device != x.device: raise RuntimeError(f"assign device mismatch {self.device} != {x.device}") + if self.device != x.device: + if is_disk and (self.uop is self.uop.base or x.uop.op in {Ops.BITCAST, Ops.CAST}): + x = x.contiguous().to(self.device) + elif not is_disk: raise RuntimeError(f"assign device mismatch {self.device} != {x.device}") if self.dtype != x.dtype: raise RuntimeError(f"assign dtype mismatch {self.dtype} != {x.dtype}") if isinstance(self.device, tuple) and self.uop.axis != x.uop.axis: raise RuntimeError(f"multi axis mismatch {self.uop.axis} != {x.uop.axis}") - # TODO: this is a hack for writing to DISK. remove with working assign - if is_disk: - self._buffer().copyin(x._data()) - return self result = self._apply_uop(UOp.assign, x) # track view assigns (not full-buffer or assign-chain) so they can be side-realized when the buffer is read - if (buf_uop:=self.uop.base).op is Ops.BUFFER and self.uop.op is not Ops.ASSIGN and not self.uop.has_buffer_identity(): + target_uop = self.uop.base + if is_disk: + while target_uop.op is Ops.BITCAST: target_uop = target_uop.src[0].base + if self.uop.op is not Ops.ASSIGN and not self.uop.has_buffer_identity() and self.uop is not target_uop and target_uop.op in {Ops.BUFFER, Ops.COPY}: # deduplicate: if the value is already a pending assign for this buffer (e.g. __iadd__ in __setitem__), remove it - if x.uop in _pending_assigns.get(buf_uop, []): _pending_assigns[buf_uop].remove(x.uop) - _pending_assigns.setdefault(buf_uop, []).append(result.uop) + if x.uop in _pending_assigns.get(target_uop, []): _pending_assigns[target_uop].remove(x.uop) + _pending_assigns.setdefault(target_uop, []).append(result.uop) return self.replace(result) def detach(self) -> Tensor: @@ -1344,6 +1357,8 @@ def __setitem__(self, indices, v:Tensor|PyConst|list|tuple) -> None: if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype) self.assign(self._getitem(indices, v)) elif is_disk or self.uop.is_realized: # basic setitem, self is realized. TODO: disk uop.base is a COPY and not realized + if is_disk and any(isinstance(i, slice) and i.step not in (None, 1) for i in idx): + raise RuntimeError("strided setitem is not supported for DISK tensors") self[indices].assign(v) else: # basic setitem, self is not realized if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)