From a3bfc656809b1902553881c9992fc4241cb6da29 Mon Sep 17 00:00:00 2001
From: Chen-Yu Yang
Date: Tue, 24 Feb 2026 20:15:54 -0500
Subject: [PATCH] disk assign change

---
Review notes (commentary section, not part of the commit):
- This copy of the patch arrived whitespace-collapsed. Line breaks and
  indentation below are reconstructed; check them against the tree before
  applying.
- Damaged input, reproduced as found rather than guessed:
  * the From: address is missing;
  * in tinygrad/schedule/rangeify.py the tail of the split_reduceop context
    line and the "@@" header of the following PatternMatcher hunk are
    missing, so that hunk cannot be applied as-is;
  * the last tinygrad/uop/ops.py hunk stops after "@property", one context
    line short of its declared length.
- Changed in review (tinygrad/engine/realize.py, _extract_store_copy_offset):
  * return None unless the STORE has exactly two sources, instead of failing
    on tuple unpacking;
  * return None for a negative destination offset instead of handing it to
    dest.view() through BufferCopyOffset.
  Added-line counts are unchanged; the "index" blob ids no longer match.
- Open questions for the author:
  * copy_to_device tests startswith("DISK") while every other site in this
    patch tests ("DISK", "TINYFS") -- is TINYFS excluded on purpose?
  * copy_to_device now returns an ASSIGN for DISK, so the
    "src.op is Ops.COPY" early-out in rewrite_cross_device_disk_assign does
    not recognise that output for DISK targets -- please confirm the rule
    reaches a fixed point.
  * lower_disk_store_copy copies all of ctx[1]; please confirm the matched
    STORE always covers the whole source buffer (e.g. source offset 0).

 tinygrad/engine/realize.py    | 41 +++++++++++++++++++++++++++++++++++
 tinygrad/nn/state.py          |  7 +++++-
 tinygrad/schedule/rangeify.py | 38 +++++++++++++++++++++++++++-----
 tinygrad/tensor.py            |  4 ----
 tinygrad/uop/ops.py           |  3 +++
 5 files changed, 83 insertions(+), 10 deletions(-)

diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py
index b18546970b566..e8754d2e44f2d 100644
--- a/tinygrad/engine/realize.py
+++ b/tinygrad/engine/realize.py
@@ -89,6 +89,14 @@ def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False):
       Device[dest.device].synchronize()
       return time.perf_counter() - st
 
+class BufferCopyOffset(BufferCopy):
+  def __init__(self, total_sz:int, dest_device:str, src_device:str, dest_off_bytes:int):
+    super().__init__(total_sz, dest_device, src_device)
+    self.dest_off_bytes = dest_off_bytes
+  def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False):
+    dest, src = rawbufs[0:2]
+    return super().__call__([dest.view(src.size, src.dtype, self.dest_off_bytes).ensure_allocated(), src], var_vals, wait=wait)
+
 class BufferXfer(BufferCopy):
   def copy(self, dest, src):
     dest.allocator._transfer(dest._buf, src._buf, dest.nbytes, src_dev=src.allocator.dev, dest_dev=dest.allocator.dev)
@@ -123,8 +131,41 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:
 
 # **************** lowering functions ****************
 
+def _extract_store_copy_offset(sink:UOp) -> int|None:
+  if sink.op is Ops.PROGRAM:
+    if len(sink.src) == 0 or sink.src[0].op is not Ops.SINK: return None
+    sink = sink.src[0]
+  if sink.op is not Ops.SINK or len(sink.src) != 1: return None
+  snk_child = sink.src[0]
+  if snk_child.op is Ops.END and snk_child.src[0].op is Ops.STORE: store = snk_child.src[0]
+  elif snk_child.op is Ops.STORE: store = snk_child
+  else: return None
+  if len(store.src) != 2: return None  # anything but (target, value) is not a plain copy
+  target, source = store.src
+  if target.op is not Ops.INDEX or source.op is not Ops.INDEX: return None
+  if target.src[0].op is not Ops.PARAM or source.src[0].op is not Ops.PARAM: return None
+  if len(target.src) != 2 or len(source.src) != 2: return None
+
+  def _base_plus_const(x:UOp) -> tuple[UOp|None, int]:
+    if x.op is Ops.CONST: return None, x.arg
+    if x.op is Ops.ADD and x.src[0].op is Ops.CONST: return x.src[1], x.src[0].arg
+    if x.op is Ops.ADD and x.src[1].op is Ops.CONST: return x.src[0], x.src[1].arg
+    return x, 0
+
+  target_base, target_off = _base_plus_const(target.src[1])
+  source_base, source_off = _base_plus_const(source.src[1])
+  if target_base is not source_base: return None
+  return dest_off if (dest_off:=target_off - source_off) >= 0 else None  # a negative offset is not a valid dest view
+
+def lower_disk_store_copy(ctx:list[Buffer|None], sink:UOp):
+  if len(ctx) < 2 or ctx[0] is None or ctx[1] is None: return None
+  if not (isinstance(ctx[0].device, str) and ctx[0].device.startswith(("DISK", "TINYFS"))): return None
+  if (dest_off:=_extract_store_copy_offset(sink)) is None: return None
+  return BufferCopyOffset(ctx[1].nbytes, ctx[0].device, ctx[1].device, dest_off*ctx[0].dtype.itemsize)
+
 # NOTE: ctx is the buffers
 si_lowerer = PatternMatcher([
+  (UPat((Ops.SINK, Ops.PROGRAM), name="sink"), lower_disk_store_copy),
   (UPat((Ops.SINK, Ops.PROGRAM), name="sink"), lambda ctx,sink: get_runner(ctx[0].device, sink)),
   (UPat(Ops.BUFFER_VIEW), lambda ctx: ViewOp(ctx[0])),
   (UPat(Ops.COPY, name="copy"), lambda ctx,copy: (BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py
index 7df92bf95d3b1..280145e740229 100644
--- a/tinygrad/nn/state.py
+++ b/tinygrad/nn/state.py
@@ -80,7 +80,12 @@ def safe_save(tensors:dict[str, Tensor], fn:str, metadata:dict[str, Any]|None=No
   t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
   t[0:8].assign(Tensor([len(j)], dtype=dtypes.int64, device="CPU").bitcast(dtypes.uint8))
   t[8:8+len(j)].assign(list(j.encode('utf-8')))
-  for k,v in safe_load(t).items(): v.assign(tensors[k])
+  data_start = 8 + len(j)
+  for k,v in tensors.items():
+    st, en = headers[k]['data_offsets']
+    t[data_start+st:data_start+en].assign(v.to("CPU").bitcast(dtypes.uint8).flatten())
+  # flush pending view-assigns so the file is fully written before returning
+  t.realize()
 
 # state dict
 
diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py
index 7c07e41e844ca..b642557884eaf 100644
--- a/tinygrad/schedule/rangeify.py
+++ b/tinygrad/schedule/rangeify.py
@@ -40,6 +40,15 @@ def fix_assign_hazard(assign:UOp, target:UOp, src:UOp):
   if any(s.op in unsafe and target.base in s.backward_slice for s in src.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS)):
     return assign.replace(src=(target, src.contiguous()))
 
+def rewrite_cross_device_disk_assign(assign:UOp, target:UOp, src:UOp):
+  if not (isinstance(target.device, str) and target.device.startswith(("DISK", "TINYFS"))): return None
+  # Slice/view/bitcast-target assigns need destination offset semantics. Copying RHS into a standalone DISK buffer
+  # loses that mapping and can also open the file with the wrong (smaller) size.
+  if target.op not in {Ops.BUFFER, Ops.COPY}: return assign.replace(src=(target, src.contiguous()))
+  if src.op is Ops.COPY and src.device == target.device: return None
+  # materialize source on its native device, then lower the write as a cross-device COPY into the assign target
+  return assign.replace(src=(target, src.contiguous().copy_to_device(target.device)))
+
 def normalize_assign_target_chain(assign:UOp, target:UOp, src:UOp):
   root_target = target
   while root_target.op is Ops.ASSIGN: root_target = root_target.src[0]
@@ -47,6 +56,12 @@ def normalize_assign_target_chain(assign:UOp, target:UOp, src:UOp):
   if target in src.toposort(): src = src.contiguous()
   return assign.replace(src=(root_target, src))
 
+def move_bitcast_from_assign_target(target:UOp, src:UOp):
+  # Keep bitcast-on-target for DISK/TINYFS, so assign can lower to a raw copy without creating
+  # a byte-pack kernel (e.g. float->uchar), which some CPU renderers can't emit correctly.
+  if isinstance(target.device, str) and target.device.startswith(("DISK", "TINYFS")): return None
+  return target.assign(src.bitcast(target.dtype))
+
 def split_reduceop(reduce:UOp, x:UOp):
   if prod(reduce.shape) == 0: return None
   if not SPLIT_REDUCEOP or not all_int(x.shape) or (prod(x.shape)//prod(reduce.shape)) UOp|None:
 
   # move bitcast from assign target to source: a.bitcast(X).assign(src) -> a.assign(src.bitcast(a.dtype))
   (UPat(Ops.ASSIGN, src=(UPat(Ops.BITCAST, src=(UPat(name="target"),)), UPat(name="src"))),
-   lambda target, src: target.assign(src.bitcast(target.dtype))),
+   move_bitcast_from_assign_target),
 
   # if assign target is itself an ASSIGN chain, canonicalize to the original buffer target
   (UPat(Ops.ASSIGN, src=(UPat(Ops.ASSIGN, name="target"), UPat(name="src")), allow_any_len=True, name="assign"), normalize_assign_target_chain),
 
   # make source contiguous if it has hazardous movement ops on the dest buffer
   (UPat(Ops.ASSIGN, src=(UPat.var("target"), UPat.var("src")), name="assign"), fix_assign_hazard),
+
+  # lower cross-device assigns to DISK/TINYFS as COPYs (DISK has no renderer)
+  (UPat(Ops.ASSIGN, src=(UPat.var("target"), UPat.var("src")), name="assign"), rewrite_cross_device_disk_assign),
 ])
 
 # *****************
@@ -316,6 +334,9 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
   sdtype = x.dtype.ptr(size=size, addrspace=x.arg.addrspace)
   if (assign := x.src[0]).op is Ops.ASSIGN:
     assign_target, assign_src = assign.src[0], assign.src[1]
+    if assign_target.op is Ops.BITCAST and assign_target.src[0].op is Ops.INDEX and isinstance(assign_target.src[0].src[0]._device, str) \
+        and assign_target.src[0].src[0]._device.startswith(("DISK", "TINYFS")):
+      assign_target = assign_target.src[0].replace(dtype=assign_target.dtype)
     assert assign_target.op is Ops.INDEX, f"{assign_target.op} is not index"
     while assign_src.op is Ops.NOOP: assign_src = assign_src.src[0]
     # skip self-assign from same-device copy, otherwise create the store
@@ -460,18 +481,25 @@ def split_store(x:UOp) -> UOp|None:
   # local kernel rewrite
   lctx = LocalAddBufferContext()
   ret = graph_rewrite(x, to_define_global+pm_flatten_range+rangeify_codegen, ctx=lctx, name="kernel split", bottom_up=True)
+  kernel_bufs = tuple(lctx.map.values())
 
   # SINK requires all buffers on the same device, but COPY/BUFFER_VIEW/ENCDEC are cross-device or special hardware ops
-  if ret.op is Ops.STORE: stored = ret.src[1]
-  elif ret.op is Ops.END and ret.src[0].op is Ops.STORE: stored = ret.src[0].src[1]
+  if ret.op is Ops.STORE: stored, store = ret.src[1], ret
+  elif ret.op is Ops.END and ret.src[0].op is Ops.STORE: stored, store = ret.src[0].src[1], ret.src[0]
   else: raise RuntimeError(f"unknown kernel type {ret.op}")
+  dest_dev = None
+  if store.src[0].op is Ops.INDEX and store.src[0].src[0].op is Ops.PARAM and isinstance(store.src[0].src[0].arg, int) and store.src[0].src[0].arg < len(kernel_bufs):
+    dest_dev = kernel_bufs[store.src[0].src[0].arg]._device
+  is_disk_dest = isinstance(dest_dev, str) and dest_dev.startswith(("DISK", "TINYFS"))
   if stored.op in {Ops.COPY, Ops.BUFFER_VIEW}: ret = stored.replace(src=stored.src + ret.ended_ranges)
   elif stored.op is Ops.ENCDEC: ret = stored
   else: ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts))
 
-  kernel = ret.call(*lctx.map.values(), *lctx.vars.keys())
+  kernel = ret.call(*kernel_bufs, *lctx.vars.keys())
   if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src[1:] if x.op is not Ops.BIND]):
-    raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src[1:])}")
+    # DISK/TINYFS stores with direct source indexing are lowered to BufferCopyOffset in realize.py.
+    if not (is_disk_dest and stored.op is Ops.INDEX and stored.src[0].op in {Ops.BUFFER, Ops.PARAM}):
+      raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src[1:])}")
   return kernel
 
 split_kernels = PatternMatcher([
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 924769b655452..519452485cefb 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -319,10 +319,6 @@ def assign(self, x:Tensor|PyConst|list|tuple) -> Tensor:
     if self.dtype != x.dtype: raise RuntimeError(f"assign dtype mismatch {self.dtype} != {x.dtype}")
     if isinstance(self.device, tuple) and self.uop.axis != x.uop.axis:
       raise RuntimeError(f"multi axis mismatch {self.uop.axis} != {x.uop.axis}")
-    # TODO: this is a hack for writing to DISK. remove with working assign
-    if is_disk:
-      self._buffer().copyin(x._data())
-      return self
     result = self._apply_uop(UOp.assign, x)
     # track view assigns (not full-buffer or assign-chain) so they can be side-realized when the buffer is read
     if (buf_uop:=self.uop.base).op is Ops.BUFFER and self.uop.op is not Ops.ASSIGN and not self.uop.has_buffer_identity():
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index d7603b5e5d0c9..572da461ca913 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -546,6 +546,8 @@ def shard(self, devices:tuple[str, ...], axis:int) -> UOp: return self.copy_to_d
   def copy_to_device(self, device:str|tuple[str, ...]|UOp, arg=None):
     assert arg is None or isinstance(self.device, tuple)
     inp = self if arg is None else UOp(Ops.MSELECT, self.dtype, src=(self,), arg=arg)
+    if arg is None and isinstance(device, str) and device.startswith("DISK"):
+      return UOp.new_buffer(device, inp.shard_size, inp.dtype).reshape(inp.max_shard_shape).assign(inp)
     return UOp(Ops.COPY, self.dtype, (inp, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device))
   def mselect(self, arg:int) -> UOp: return UOp(Ops.MSELECT, self.dtype, (self,), arg)
   @property
@@ -559,6 +561,7 @@ def base(self) -> UOp:
     if self.op in GroupOp.Movement: return self.src[0].base
     if self.op is Ops.MULTI: return self.src[0].base  # MULTI is really a VIEW
     if self.op is Ops.DETACH: return self.src[0].base  # DETACH can't change base
+    if self.op is Ops.ASSIGN and self.src[0].op is not Ops.ASSIGN: return self.src[0].base
     return self
   @property