Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions tinygrad/engine/realize.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False):
Device[dest.device].synchronize()
return time.perf_counter() - st

class BufferCopyOffset(BufferCopy):
  """BufferCopy that writes into a window of the destination buffer starting at a fixed byte offset."""
  def __init__(self, total_sz:int, dest_device:str, src_device:str, dest_off_bytes:int):
    super().__init__(total_sz, dest_device, src_device)
    # byte offset handed to `dest.view(...)` when the copy runs
    self.dest_off_bytes = dest_off_bytes

  def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False):
    dst_buf, src_buf = rawbufs[0:2]
    # carve a source-sized, source-typed view out of the destination, then delegate to the plain copy
    dst_window = dst_buf.view(src_buf.size, src_buf.dtype, self.dest_off_bytes)
    return super().__call__([dst_window.ensure_allocated(), src_buf], var_vals, wait=wait)

class BufferXfer(BufferCopy):
  """BufferCopy variant whose copy step goes through the destination allocator's `_transfer`."""
  def copy(self, dest, src):
    dest_alloc = dest.allocator
    # transfer `dest.nbytes` bytes between the two underlying buffers, naming both allocator devices
    dest_alloc._transfer(dest._buf, src._buf, dest.nbytes, src_dev=src.allocator.dev, dest_dev=dest_alloc.dev)

Expand Down Expand Up @@ -123,8 +131,41 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:

# **************** lowering functions ****************

def _extract_store_copy_offset(sink:UOp) -> int|None:
  """Return the constant index offset of a single-store copy kernel, or None if `sink` is not one.

  Accepted shape: a SINK (optionally wrapped as the first source of a PROGRAM) with exactly one child,
  which is either a STORE or an END whose first source is a STORE. The STORE must have exactly two
  sources, both two-source INDEX ops whose first source is a PARAM. The two index expressions must
  share the same base (or both be plain constants); the result is the constant part of the target index
  minus the constant part of the source index.

  Any structural mismatch returns None rather than raising, so the caller can fall through to another lowering.
  """
  if sink.op is Ops.PROGRAM:
    if len(sink.src) == 0 or sink.src[0].op is not Ops.SINK: return None
    sink = sink.src[0]
  if sink.op is not Ops.SINK or len(sink.src) != 1: return None
  snk_child = sink.src[0]
  if snk_child.op is Ops.STORE: store = snk_child
  # guard the src[0] access: an END without sources is simply not a match
  elif snk_child.op is Ops.END and len(snk_child.src) > 0 and snk_child.src[0].op is Ops.STORE: store = snk_child.src[0]
  else: return None
  # a STORE with a different number of sources is not the plain (target, value) form; decline instead of
  # letting the tuple unpack below raise ValueError
  if len(store.src) != 2: return None
  target, source = store.src
  if target.op is not Ops.INDEX or source.op is not Ops.INDEX: return None
  # check arity before dereferencing src[0] / src[1]
  if len(target.src) != 2 or len(source.src) != 2: return None
  if target.src[0].op is not Ops.PARAM or source.src[0].op is not Ops.PARAM: return None

  def _base_plus_const(x:UOp) -> tuple[UOp|None, int]:
    # split an index expression into (non-constant base, constant addend); a bare CONST has no base
    if x.op is Ops.CONST: return None, x.arg
    if x.op is Ops.ADD and x.src[0].op is Ops.CONST: return x.src[1], x.src[0].arg
    if x.op is Ops.ADD and x.src[1].op is Ops.CONST: return x.src[0], x.src[1].arg
    return x, 0

  target_base, target_off = _base_plus_const(target.src[1])
  source_base, source_off = _base_plus_const(source.src[1])
  # identity comparison: both sides must index with the very same base UOp (or both have none)
  if target_base is not source_base: return None
  # NOTE(review): the difference can be negative when the source constant is larger -- confirm callers handle that
  return target_off - source_off

def lower_disk_store_copy(ctx:list[Buffer|None], sink:UOp):
  """Lower a store kernel whose first buffer is on DISK/TINYFS to a BufferCopyOffset.

  Returns None when fewer than two buffers are given, either of the first two is None, the first buffer's
  device is not a DISK/TINYFS string, or `_extract_store_copy_offset` does not recognise `sink`.
  """
  if len(ctx) < 2: return None
  dest_buf, src_buf = ctx[0], ctx[1]
  if dest_buf is None or src_buf is None: return None
  dest_device = dest_buf.device
  if not isinstance(dest_device, str) or not dest_device.startswith(("DISK", "TINYFS")): return None
  index_off = _extract_store_copy_offset(sink)
  if index_off is None: return None
  # the extracted offset is in index units; scale by the destination itemsize for the byte-offset argument
  return BufferCopyOffset(src_buf.nbytes, dest_device, src_buf.device, index_off*dest_buf.dtype.itemsize)

# NOTE: ctx is the buffers
si_lowerer = PatternMatcher([
(UPat((Ops.SINK, Ops.PROGRAM), name="sink"), lower_disk_store_copy),
(UPat((Ops.SINK, Ops.PROGRAM), name="sink"), lambda ctx,sink: get_runner(ctx[0].device, sink)),
(UPat(Ops.BUFFER_VIEW), lambda ctx: ViewOp(ctx[0])),
(UPat(Ops.COPY, name="copy"), lambda ctx,copy: (BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
Expand Down
7 changes: 6 additions & 1 deletion tinygrad/nn/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,12 @@ def safe_save(tensors:dict[str, Tensor], fn:str, metadata:dict[str, Any]|None=No
t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
t[0:8].assign(Tensor([len(j)], dtype=dtypes.int64, device="CPU").bitcast(dtypes.uint8))
t[8:8+len(j)].assign(list(j.encode('utf-8')))
for k,v in safe_load(t).items(): v.assign(tensors[k])
data_start = 8 + len(j)
for k,v in tensors.items():
st, en = headers[k]['data_offsets']
t[data_start+st:data_start+en].assign(v.to("CPU").bitcast(dtypes.uint8).flatten())
# flush pending view-assigns so the file is fully written before returning
t.realize()

# state dict

Expand Down
38 changes: 33 additions & 5 deletions tinygrad/schedule/rangeify.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,28 @@ def fix_assign_hazard(assign:UOp, target:UOp, src:UOp):
if any(s.op in unsafe and target.base in s.backward_slice for s in src.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS)):
return assign.replace(src=(target, src.contiguous()))

def rewrite_cross_device_disk_assign(assign:UOp, target:UOp, src:UOp):
  """Rewrite the source of an ASSIGN whose target device is DISK/TINYFS; return None when nothing applies."""
  target_dev = target.device
  targets_disk = isinstance(target_dev, str) and target_dev.startswith(("DISK", "TINYFS"))
  if not targets_disk: return None
  if target.op not in {Ops.BUFFER, Ops.COPY}:
    # Slice/view/bitcast-target assigns need destination offset semantics. Copying RHS into a standalone DISK
    # buffer loses that mapping and can also open the file with the wrong (smaller) size, so only force the
    # source contiguous here.
    new_src = src.contiguous()
  elif src.op is Ops.COPY and src.device == target_dev:
    # source is already a COPY onto the target device: leave the assign untouched
    return None
  else:
    # materialize source on its native device, then lower the write as a cross-device COPY into the assign target
    new_src = src.contiguous().copy_to_device(target_dev)
  return assign.replace(src=(target, new_src))

def normalize_assign_target_chain(assign:UOp, target:UOp, src:UOp):
  """Re-point an ASSIGN whose target is itself an ASSIGN at the innermost non-ASSIGN target of that chain."""
  # when RHS depends on the previous assign result, break with contiguous
  new_src = src.contiguous() if target in src.toposort() else src
  # follow src[0] through nested ASSIGNs until something that is not an ASSIGN is reached
  chain_root = target
  while chain_root.op is Ops.ASSIGN:
    chain_root = chain_root.src[0]
  return assign.replace(src=(chain_root, new_src))

def move_bitcast_from_assign_target(target:UOp, src:UOp):
  """Move a bitcast from the assign target onto the source, except for DISK/TINYFS targets (returns None)."""
  target_dev = target.device
  if isinstance(target_dev, str) and target_dev.startswith(("DISK", "TINYFS")):
    # Keep bitcast-on-target for DISK/TINYFS, so assign can lower to a raw copy without creating
    # a byte-pack kernel (e.g. float->uchar), which some CPU renderers can't emit correctly.
    return None
  casted_src = src.bitcast(target.dtype)
  return target.assign(casted_src)

def split_reduceop(reduce:UOp, x:UOp):
if prod(reduce.shape) == 0: return None
if not SPLIT_REDUCEOP or not all_int(x.shape) or (prod(x.shape)//prod(reduce.shape))<getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return None
Expand Down Expand Up @@ -122,13 +137,16 @@ def resolve_call(c:UOp, allow_param_mismatch=False) -> UOp|None:

# move bitcast from assign target to source: a.bitcast(X).assign(src) -> a.assign(src.bitcast(a.dtype))
(UPat(Ops.ASSIGN, src=(UPat(Ops.BITCAST, src=(UPat(name="target"),)), UPat(name="src"))),
lambda target, src: target.assign(src.bitcast(target.dtype))),
move_bitcast_from_assign_target),

# if assign target is itself an ASSIGN chain, canonicalize to the original buffer target
(UPat(Ops.ASSIGN, src=(UPat(Ops.ASSIGN, name="target"), UPat(name="src")), allow_any_len=True, name="assign"), normalize_assign_target_chain),

# make source contiguous if it has hazardous movement ops on the dest buffer
(UPat(Ops.ASSIGN, src=(UPat.var("target"), UPat.var("src")), name="assign"), fix_assign_hazard),

# lower cross-device assigns to DISK/TINYFS as COPYs (DISK has no renderer)
(UPat(Ops.ASSIGN, src=(UPat.var("target"), UPat.var("src")), name="assign"), rewrite_cross_device_disk_assign),
])

# *****************
Expand Down Expand Up @@ -316,6 +334,9 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
sdtype = x.dtype.ptr(size=size, addrspace=x.arg.addrspace)
if (assign := x.src[0]).op is Ops.ASSIGN:
assign_target, assign_src = assign.src[0], assign.src[1]
if assign_target.op is Ops.BITCAST and assign_target.src[0].op is Ops.INDEX and isinstance(assign_target.src[0].src[0]._device, str) \
and assign_target.src[0].src[0]._device.startswith(("DISK", "TINYFS")):
assign_target = assign_target.src[0].replace(dtype=assign_target.dtype)
assert assign_target.op is Ops.INDEX, f"{assign_target.op} is not index"
while assign_src.op is Ops.NOOP: assign_src = assign_src.src[0]
# skip self-assign from same-device copy, otherwise create the store
Expand Down Expand Up @@ -460,18 +481,25 @@ def split_store(x:UOp) -> UOp|None:
# local kernel rewrite
lctx = LocalAddBufferContext()
ret = graph_rewrite(x, to_define_global+pm_flatten_range+rangeify_codegen, ctx=lctx, name="kernel split", bottom_up=True)
kernel_bufs = tuple(lctx.map.values())

# SINK requires all buffers on the same device, but COPY/BUFFER_VIEW/ENCDEC are cross-device or special hardware ops
if ret.op is Ops.STORE: stored = ret.src[1]
elif ret.op is Ops.END and ret.src[0].op is Ops.STORE: stored = ret.src[0].src[1]
if ret.op is Ops.STORE: stored, store = ret.src[1], ret
elif ret.op is Ops.END and ret.src[0].op is Ops.STORE: stored, store = ret.src[0].src[1], ret.src[0]
else: raise RuntimeError(f"unknown kernel type {ret.op}")
dest_dev = None
if store.src[0].op is Ops.INDEX and store.src[0].src[0].op is Ops.PARAM and isinstance(store.src[0].src[0].arg, int) and store.src[0].src[0].arg < len(kernel_bufs):
dest_dev = kernel_bufs[store.src[0].src[0].arg]._device
is_disk_dest = isinstance(dest_dev, str) and dest_dev.startswith(("DISK", "TINYFS"))
if stored.op in {Ops.COPY, Ops.BUFFER_VIEW}: ret = stored.replace(src=stored.src + ret.ended_ranges)
elif stored.op is Ops.ENCDEC: ret = stored
else: ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts))

kernel = ret.call(*lctx.map.values(), *lctx.vars.keys())
kernel = ret.call(*kernel_bufs, *lctx.vars.keys())
if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src[1:] if x.op is not Ops.BIND]):
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src[1:])}")
# DISK/TINYFS stores with direct source indexing are lowered to BufferCopyOffset in realize.py.
if not (is_disk_dest and stored.op is Ops.INDEX and stored.src[0].op in {Ops.BUFFER, Ops.PARAM}):
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src[1:])}")
return kernel

split_kernels = PatternMatcher([
Expand Down
4 changes: 0 additions & 4 deletions tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,10 +319,6 @@ def assign(self, x:Tensor|PyConst|list|tuple) -> Tensor:
if self.dtype != x.dtype: raise RuntimeError(f"assign dtype mismatch {self.dtype} != {x.dtype}")
if isinstance(self.device, tuple) and self.uop.axis != x.uop.axis: raise RuntimeError(f"multi axis mismatch {self.uop.axis} != {x.uop.axis}")

# TODO: this is a hack for writing to DISK. remove with working assign
if is_disk:
self._buffer().copyin(x._data())
return self
result = self._apply_uop(UOp.assign, x)
# track view assigns (not full-buffer or assign-chain) so they can be side-realized when the buffer is read
if (buf_uop:=self.uop.base).op is Ops.BUFFER and self.uop.op is not Ops.ASSIGN and not self.uop.has_buffer_identity():
Expand Down
3 changes: 3 additions & 0 deletions tinygrad/uop/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,8 @@ def shard(self, devices:tuple[str, ...], axis:int) -> UOp: return self.copy_to_d
def copy_to_device(self, device:str|tuple[str, ...]|UOp, arg=None):
  """Build the UOp that places this value on `device`.

  With `arg` set (only allowed when `self.device` is a tuple) the input is first wrapped in an MSELECT.
  With `arg` unset and a device string starting with "DISK", the result is an ASSIGN into a freshly
  created buffer on that device; otherwise it is a COPY op.
  """
  assert arg is None or isinstance(self.device, tuple)
  if arg is None:
    inp = self
    # NOTE(review): only the "DISK" prefix takes this path; "TINYFS" devices get a COPY below -- confirm intended
    if isinstance(device, str) and device.startswith("DISK"):
      disk_buf = UOp.new_buffer(device, inp.shard_size, inp.dtype)
      return disk_buf.reshape(inp.max_shard_shape).assign(inp)
  else:
    inp = UOp(Ops.MSELECT, self.dtype, src=(self,), arg=arg)
  # a UOp device is used as-is; anything else is wrapped in a DEVICE op
  device_uop = device if isinstance(device, UOp) else UOp(Ops.DEVICE, arg=device)
  return UOp(Ops.COPY, self.dtype, (inp, device_uop))
def mselect(self, arg:int) -> UOp:
  """Wrap this UOp in an MSELECT op carrying `arg`, keeping this UOp's dtype."""
  selected = UOp(Ops.MSELECT, self.dtype, (self,), arg)
  return selected
@property
Expand All @@ -559,6 +561,7 @@ def base(self) -> UOp:
if self.op in GroupOp.Movement: return self.src[0].base
if self.op is Ops.MULTI: return self.src[0].base # MULTI is really a VIEW
if self.op is Ops.DETACH: return self.src[0].base # DETACH can't change base
if self.op is Ops.ASSIGN and self.src[0].op is not Ops.ASSIGN: return self.src[0].base
return self

@property
Expand Down
Loading