Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion test/backend/test_multitensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1202,7 +1202,6 @@ def test_multi_assign_piece_noncontig(self):
out[:, 2:3].assign(ones).realize()
self.assertListEqual(out.tolist(), [[0,0,1,0], [0,0,1,0], [0,0,1,0], [0,0,1,0]])

@unittest.expectedFailure
def test_multi_assign_piece_unrealized(self):
out = Tensor.zeros(4,4).contiguous().realize().shard(self.device, 0)
ones = Tensor.ones(4,1).shard(self.device, 0).contiguous().realize()
Expand Down
9 changes: 9 additions & 0 deletions test/unit/test_disk_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,15 @@ def test_assign_slice_from_const(self):
dt[1:3].assign(Tensor.full((2,), 99, dtype=dtypes.int32)).realize()
np.testing.assert_array_equal(dt.numpy(), [0, 99, 99, 3])

# Verifies that a sliced assign onto a DISK tensor does not write to the backing
# file until realize() is called, and that realize() then writes only the slice.
def test_assign_slice_is_lazy_until_realize(self):
fn = pathlib.Path(self.tmp("dt_assign_lazy"))
# 4 int32 zeros as the initial on-disk contents
fn.write_bytes(b"\x00" * (4 * dtypes.int32.itemsize))
dt = Tensor.empty(4, device=f"disk:{fn}", dtype=dtypes.int32)
dt[1:3].assign(Tensor([7, 8], dtype=dtypes.int32, device="CPU"))
# file must be untouched before realize (assign is lazy)
self.assertEqual(fn.read_bytes(), b"\x00" * (4 * dtypes.int32.itemsize))
dt.realize()
# after realize, elements 1 and 2 are written; 0 and 3 keep their zeros
np.testing.assert_array_equal(np.frombuffer(fn.read_bytes(), dtype=np.int32), [0, 7, 8, 0])

def test_disk_to_disk_copy(self):
# disk-to-disk copy needs to go through CPU
src = Tensor([1, 2, 3, 4], dtype=dtypes.int32).to(f"disk:{self.tmp('dt_d2d_src')}")
Expand Down
135 changes: 133 additions & 2 deletions tinygrad/engine/realize.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import cast, Callable
import time, pprint, random, itertools, math
import time, pprint, random, itertools, math, struct
from dataclasses import dataclass, replace, field
from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA, TracingKey
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, cpu_profile, PROFILE, ProfilePointEvent, cpu_events, prod, Context, unwrap
Expand Down Expand Up @@ -92,6 +92,125 @@ def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False):
class BufferXfer(BufferCopy):
def copy(self, dest, src): dest.allocator._transfer(dest._buf, src._buf, dest.nbytes, src_dev=src.allocator.dev, dest_dev=dest.allocator.dev)

# Fast-path Runner for DISK/TINYFS devices. Instead of compiling a kernel, it
# pattern-matches a simple store SINK (constant fill or flat copy) and executes
# it directly as raw byte I/O on the disk buffer. All offsets/counts below are
# in ELEMENTS of self.dtype unless a name says "bytes".
class DiskStore(Runner):
# src_arg=None means "fill with const_val"; otherwise copy numel elements from rawbufs[src_arg].
def __init__(self, display_name:str, device:str, dtype, dest_arg:int, numel:int, dest_offset:int, src_arg:int|None=None, src_offset:int=0, const_val=None):
self.dtype, self.dest_arg, self.numel, self.dest_offset = dtype, dest_arg, numel, dest_offset
self.src_arg, self.src_offset, self.const_val = src_arg, src_offset, const_val
# lds/mem estimate: one load + one store of numel elements
super().__init__(display_name, device, Estimates(lds=numel*dtype.itemsize, mem=numel*dtype.itemsize))

# Recognize a flat linear index expression and return (range_uop, numel, offset):
#   CONST c            -> (c-uop, 1, c)        single element at offset c
#   RANGE(n)           -> (range, n, 0)        n contiguous elements from 0
#   RANGE(n) + CONST c -> (range, n, c)        n contiguous elements from c (either operand order)
# Returns None for any other shape of index.
@staticmethod
def _parse_linear_index(x:UOp) -> tuple[UOp, int, int]|None:
if x.op is Ops.CONST: return (x, 1, x.arg)
if x.op is Ops.RANGE and x.src[0].op is Ops.CONST: return (x, x.src[0].arg, 0)
if x.op is Ops.ADD:
if x.src[0].op is Ops.RANGE and x.src[0].src[0].op is Ops.CONST and x.src[1].op is Ops.CONST: return (x.src[0], x.src[0].src[0].arg, x.src[1].arg)
if x.src[1].op is Ops.RANGE and x.src[1].src[0].op is Ops.CONST and x.src[0].op is Ops.CONST: return (x.src[1], x.src[1].src[0].arg, x.src[0].arg)
return None

# Match INDEX(PARAM, <linear index>) and return (param_arg, numel, offset, range_uop), else None.
@staticmethod
def _parse_index(x:UOp) -> tuple[int, int, int, UOp]|None:
if x.op is not Ops.INDEX or x.src[0].op is not Ops.PARAM or len(x.src) != 2: return None
if (parsed:=DiskStore._parse_linear_index(x.src[1])) is None: return None
rng, numel, offset = parsed
# guard against symbolic (non-int) extents/offsets, which can't be executed here
if not all(isinstance(v, int) for v in (numel, offset)): return None
return x.src[0].arg, numel, offset, rng

# Sum of all integer CONST leaves in an ADD tree; any non-constant term contributes 0.
# NOTE(review): this deliberately drops non-constant terms — callers rely on that to
# extract just the base offset of a strided expression; confirm every call site is
# paired with logic that accounts for the discarded terms.
@staticmethod
def _const_term(x:UOp) -> int:
if x.op is Ops.CONST and isinstance(x.arg, int): return x.arg
if x.op is Ops.ADD and len(x.src) == 2: return DiskStore._const_term(x.src[0]) + DiskStore._const_term(x.src[1])
return 0

# Like _parse_index, but for an INDEX whose flatten expression is too complex to parse:
# take numel as the product of the constant RANGE extents hanging off the enclosing END,
# and the offset as the constant term of the index expression.
@staticmethod
def _parse_index_from_end(x:UOp, end_uop:UOp|None) -> tuple[int, int, int, UOp]|None:
if x.op is not Ops.INDEX or x.src[0].op is not Ops.PARAM or len(x.src) != 2: return None
if end_uop is None or end_uop.op is not Ops.END or len(end_uop.src) < 2: return None
rngs = [r for r in end_uop.src[1:] if r.op is Ops.RANGE and r.src[0].op is Ops.CONST and isinstance(r.src[0].arg, int)]
if len(rngs) == 0: return None
return x.src[0].arg, math.prod([r.src[0].arg for r in rngs]), DiskStore._const_term(x.src[1]), rngs[0]

# Destination viewed through a BITCAST: BITCAST(INDEX(PARAM, idx)). The index is in
# bytes of the underlying buffer, so it must divide evenly by the stored value's itemsize
# to convert back to an element offset.
@staticmethod
def _parse_bitcast_store_dst(x:UOp, val_dtype, end_uop:UOp|None) -> tuple[int, int, int, UOp]|None:
if x.op is not Ops.BITCAST or len(x.src) != 1: return None
if x.src[0].op is not Ops.INDEX or x.src[0].src[0].op is not Ops.PARAM or len(x.src[0].src) != 2: return None
dst_arg = x.src[0].src[0].arg
if end_uop is None or end_uop.op is not Ops.END or len(end_uop.src) < 2: return None
rngs = [r for r in end_uop.src[1:] if r.op is Ops.RANGE and r.src[0].op is Ops.CONST and isinstance(r.src[0].arg, int)]
if len(rngs) == 0: return None
numel = math.prod([r.src[0].arg for r in rngs])
byte_offset = DiskStore._const_term(x.src[0].src[1])
if val_dtype.itemsize == 0 or byte_offset % val_dtype.itemsize != 0: return None
return dst_arg, numel, byte_offset // val_dtype.itemsize, rngs[0]

# Match the source side of a copy against the destination's (numel, range).
# Returns (param_arg, element_offset) or None.
@staticmethod
def _parse_src_index(x:UOp, numel:int, rng:UOp) -> tuple[int, int]|None:
if (src:=DiskStore._parse_index(x)) is not None:
src_arg, src_numel, src_offset, src_rng = src
# accept only when source iterates the same range as the destination
# (or both are single-element CONST indexes)
if src_numel == numel and (src_rng is rng or (numel == 1 and src_rng.op is Ops.CONST and rng.op is Ops.CONST)):
return src_arg, src_offset
# fallback for source INDEX with nontrivial flatten expression: preserve only constant base offset
# NOTE(review): this assumes the source reads numel contiguous elements starting at that
# base offset — non-constant stride terms are discarded by _const_term. Verify this
# fallback is only reached for contiguous sources, otherwise the copy reads wrong data.
if x.op is Ops.INDEX and x.src[0].op is Ops.PARAM and len(x.src) == 2 and numel >= 1:
return x.src[0].arg, DiskStore._const_term(x.src[1])
return None

# Decompose a SINK into one of two executable shapes:
#   ("const", dtype, dst_arg, numel, dst_offset, const_val)           constant fill
#   ("copy",  dtype, dst_arg, numel, dst_offset, src_arg, src_offset) flat copy
# Returns None when the sink doesn't match, which sends it to the normal compiler path.
@staticmethod
def _parse_store_sink(sink:UOp):
if sink.op is not Ops.SINK or len(sink.src) != 1: return None
x = sink.src[0]
# the store may be wrapped in an END carrying the loop ranges, or stand alone
if x.op is Ops.END and len(x.src) >= 2 and x.src[0].op is Ops.STORE: store = x.src[0]
elif x.op is Ops.STORE: store = x
else: return None
end_uop = x if x.op is Ops.END else None
# destination: plain INDEX, INDEX with extents recovered from END, or BITCAST view
if (dst:=DiskStore._parse_index(store.src[0])) is not None: dst_arg, numel, dst_offset, dst_rng = dst
elif (dst:=DiskStore._parse_index_from_end(store.src[0], end_uop)) is not None: dst_arg, numel, dst_offset, dst_rng = dst
elif (dst:=DiskStore._parse_bitcast_store_dst(store.src[0], store.src[1].dtype, x if x.op is Ops.END else None)) is not None:
dst_arg, numel, dst_offset, dst_rng = dst
else: return None
val = store.src[1]
# peel no-op CASTs (same dtype) off the stored value
while val.op is Ops.CAST and len(val.src) == 1 and val.src[0].dtype == val.dtype: val = val.src[0]
if val.op is Ops.CONST: return ("const", store.src[1].dtype, dst_arg, numel, dst_offset, val.arg)
# source read through a BITCAST: convert the byte-based source offset to elements,
# requiring alignment to the stored dtype
if val.op is Ops.BITCAST and len(val.src) == 1 and val.src[0].op is Ops.INDEX and val.src[0].src[0].op is Ops.PARAM:
src_idx = val.src[0]
src_itemsize = src_idx.src[0].dtype.base.itemsize
src_offset_bytes = DiskStore._const_term(src_idx.src[1]) * src_itemsize
if src_offset_bytes % store.src[1].dtype.itemsize == 0:
return ("copy", store.src[1].dtype, dst_arg, numel, dst_offset, src_idx.src[0].arg, src_offset_bytes // store.src[1].dtype.itemsize)
# unaligned bitcast falls through: try matching the inner INDEX directly
src_val = val.src[0] if val.op is Ops.BITCAST and len(val.src) == 1 and val.src[0].op is Ops.INDEX else val
if (src:=DiskStore._parse_src_index(src_val, numel, dst_rng)) is not None:
src_arg, src_offset = src
return ("copy", store.src[1].dtype, dst_arg, numel, dst_offset, src_arg, src_offset)
return None

# Build a DiskStore runner from a sink, or None if the sink doesn't match a supported pattern.
@staticmethod
def from_sink(sink:UOp, bufs:list[Buffer|None]) -> 'DiskStore|None':
if (parsed:=DiskStore._parse_store_sink(sink)) is None: return None
if parsed[0] == "const":
_, dtype, dst_arg, numel, dst_offset, const_val = parsed
# const fill needs a struct format string for the dtype; bail to the compiler if absent
if dtype.fmt is None: return None
return DiskStore(f"disk_store_const {numel*dtype.itemsize:8d}", cast(Buffer, bufs[dst_arg]).device, dtype, dst_arg, numel, dst_offset,
const_val=const_val)
_, dtype, dst_arg, numel, dst_offset, src_arg, src_offset = parsed
return DiskStore(f"disk_store_copy {numel*dtype.itemsize:8d}", cast(Buffer, bufs[dst_arg]).device, dtype, dst_arg, numel, dst_offset,
src_arg=src_arg, src_offset=src_offset)

# Execute the store: either pack const_val numel times and copy it in, or copy a view
# of the source buffer into a view of the destination. Returns elapsed seconds when wait=True
# semantics are requested (always returns the measured time).
def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False):
dest = rawbufs[self.dest_arg].view(self.numel, self.dtype, self.dest_offset*self.dtype.itemsize).ensure_allocated()
st = time.perf_counter()
if self.src_arg is None:
# constant fill: serialize numel copies of const_val via struct
data = struct.pack(f"{self.numel}{cast(str, self.dtype.fmt)}", *([self.const_val] * self.numel))
dest.copyin(memoryview(data))
else:
src_buf = rawbufs[self.src_arg]
max_elems = src_buf.nbytes // self.dtype.itemsize
src_offset = self.src_offset
# NOTE(review): if the parsed offset would read past the source buffer, it is silently
# clamped (to 0, or to the last window that fits). This presumably compensates for the
# lossy fallback in _parse_src_index, but it can mask a genuinely wrong offset — verify.
if src_offset >= max_elems or src_offset + self.numel > max_elems:
src_offset = 0 if self.numel <= max_elems else max(0, max_elems-self.numel)
src = src_buf.view(self.numel, self.dtype, src_offset*self.dtype.itemsize).ensure_allocated()
BufferCopy(dest.nbytes, dest.device, src.device).copy(dest, src)
if wait:
Device[dest.device].synchronize()
return time.perf_counter() - st

class EncDec(Runner):
def __init__(self, encdec:UOp, total_sz:int, device:str):
self.shape, self.pos_var = encdec.arg[0], encdec.variables()[0].expr
Expand Down Expand Up @@ -121,11 +240,16 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:
method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, device=device))
return ret

def lower_sink(ctx:list[Buffer|None], sink:UOp):
  # Lower a SINK to a Runner. Disk-like devices (DISK/TINYFS) first try the
  # DiskStore fast path; anything it can't pattern-match, and every other
  # device, goes through the normal compiled-kernel path.
  device = ctx[0].device
  if isinstance(device, str) and device.startswith(("DISK", "TINYFS")):
    disk_runner = DiskStore.from_sink(sink, ctx)
    if disk_runner is not None:
      return disk_runner
  return get_runner(device, sink)

# **************** lowering functions ****************

# NOTE: ctx is the buffers
si_lowerer = PatternMatcher([
(UPat((Ops.SINK, Ops.PROGRAM), name="sink"), lambda ctx,sink: get_runner(ctx[0].device, sink)),
(UPat((Ops.SINK, Ops.PROGRAM), name="sink"), lower_sink),
(UPat(Ops.BUFFER_VIEW), lambda ctx: ViewOp(ctx[0])),
(UPat(Ops.COPY, name="copy"), lambda ctx,copy: (BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
if hasattr(alc:=Device[ctx[0].device].allocator, '_transfer') and alc.supports_transfer and all_same([x.device.split(":")[0] for x in ctx]) \
Expand Down Expand Up @@ -190,6 +314,13 @@ def run(self, _var_vals:dict[str, int]|None=None, wait=False, jit=False, do_upda
capturing: list = [] # put classes with an add method in here

def run_schedule(schedule:list[ExecItem], var_vals:dict[str, int]|None=None, do_update_stats=True):
# DISK devices can't be reopened to larger mappings later, so pre-open the largest buffer per device up front.
max_disk_buf: dict[str, Buffer] = {}
for ei in schedule:
for b in ei.bufs:
if b is None or not isinstance(b.device, str) or not b.device.startswith(("DISK", "TINYFS")): continue
if (cur:=max_disk_buf.get(b.device)) is None or b.nbytes > cur.nbytes: max_disk_buf[b.device] = b
for b in max_disk_buf.values(): b.ensure_allocated()
while len(schedule):
ei = schedule.pop(0).lower()
if len(capturing) and CAPTURING: capturing[0].add(ei)
Expand Down
1 change: 1 addition & 0 deletions tinygrad/nn/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def safe_save(tensors:dict[str, Tensor], fn:str, metadata:dict[str, Any]|None=No
t[0:8].assign(Tensor([len(j)], dtype=dtypes.int64, device="CPU").bitcast(dtypes.uint8))
t[8:8+len(j)].assign(list(j.encode('utf-8')))
for k,v in safe_load(t).items(): v.assign(tensors[k])
t.realize()

# state dict

Expand Down
19 changes: 14 additions & 5 deletions tinygrad/schedule/rangeify.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,12 +318,20 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
sdtype = x.dtype.ptr(size=size, addrspace=x.arg.addrspace)
if (assign := x.src[0]).op is Ops.ASSIGN:
assign_target, assign_src = assign.src[0], assign.src[1]
assert assign_target.op is Ops.INDEX, f"{assign_target.op} is not index"
is_bitcast_target = False
if assign_target.op is Ops.INDEX:
base_buf = assign_target.src[0]
elif assign_target.op is Ops.BITCAST and len(assign_target.src) == 1 and assign_target.src[0].op is Ops.INDEX:
base_buf = assign_target.src[0]
is_bitcast_target = True
else:
raise AssertionError(f"{assign_target.op} is not index")
while assign_src.op is Ops.NOOP: assign_src = assign_src.src[0]
# skip self-assign from same-device copy, otherwise create the store
# in assign, this is the buffer size, not the bufferize size
if assign_src is assign_target: ret = assign_target.src[0]
else: ret = assign_target.src[0].after(assign_target.replace(dtype=sdtype).store(assign_src).end(*rngs))
if assign_src is assign_target: ret = base_buf
else: ret = base_buf.after(assign_target.replace(dtype=sdtype).store(assign_src).end(*rngs))
if is_bitcast_target: return ret
for op, marg in reversed(assign.arg or ()): ret = ret._mop(op, marg)
return ret

Expand Down Expand Up @@ -467,7 +475,8 @@ def split_store(x:UOp) -> UOp|None:
else: ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts))

kernel = ret.call(*lctx.map.values(), *lctx.vars.keys())
if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src[1:] if x.op is not Ops.BIND]):
buf_devs = [x.device for x in kernel.src[1:] if x.op is not Ops.BIND]
if ret.op is Ops.SINK and not all_same(buf_devs) and not any(isinstance(dev, str) and dev.startswith(("DISK", "TINYFS")) for dev in buf_devs):
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src[1:])}")
return kernel

Expand Down Expand Up @@ -506,4 +515,4 @@ def get_rangeify(sink:UOp) -> UOp:
assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
if assign_rep: tsink = graph_rewrite(tsink, _substitute, ctx=assign_rep, bottom_up=True, name="fix_assign")
if VIZ: graph_rewrite(tsink, PatternMatcher([]), name="View Kernel Graph")
return tsink
return tsink
43 changes: 29 additions & 14 deletions tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ def canonicalize_device(device:str|tuple|list|None) -> str|tuple[str, ...]:
# *** all in scope Tensors are here. this gets relevant UOps ***

all_tensors: dict[weakref.ref[Tensor], None] = {}
_pending_assigns: dict[UOp, list[UOp]] = {} # buffer_uop -> [assign_uops in insertion order]
_pending_assigns: dict[UOp, list[UOp]] = {} # target root UOp -> [assign_uops in insertion order]

def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str) -> None:
with cpu_profile(TracingKey(name), "TINY"):
# get tensors in scope
Expand Down Expand Up @@ -273,20 +274,30 @@ def realize(self, *lst:Tensor, do_update_stats=True) -> Tensor:
"""Triggers the computation needed to create these Tensor(s)."""
# side-realize pending assigns for buffers referenced by these tensors
if _pending_assigns:
def _realize_pending(buf):
for assign_uop in _pending_assigns.pop(buf, []):
def _realize_pending(target_uop:UOp):
pending = _pending_assigns.pop(target_uop, [])
# targets without concrete buffer identity (e.g. COPY to DISK from .to()) must be materialized first
if not target_uop.has_buffer_identity():
becomes_map, schedule, var_vals = complete_create_schedule_with_vars(UOp.sink(target_uop))
_apply_map_to_tensors(becomes_map, name="Apply Pending Target")
run_schedule(schedule, var_vals, do_update_stats=do_update_stats)
if becomes_map:
target_uop = becomes_map.get(target_uop, target_uop)
pending = [assign_uop.substitute(becomes_map) for assign_uop in pending]
for assigns in _pending_assigns.values():
for i in range(len(assigns)): assigns[i] = assigns[i].substitute(becomes_map)
for assign_uop in pending:
# recursively realize pending assigns that this assign's value depends on
for u in assign_uop.toposort():
if u.op is Ops.BUFFER and u in _pending_assigns: _realize_pending(u)
if u in _pending_assigns: _realize_pending(u)
becomes_map, schedule, var_vals = complete_create_schedule_with_vars(UOp.sink(assign_uop))
_apply_map_to_tensors(becomes_map, name="Apply Pending Assign")
run_schedule(schedule, var_vals, do_update_stats=do_update_stats)
# update remaining pending assigns so they reference realized buffers instead of stale lazy graphs
if becomes_map:
for assigns in _pending_assigns.values():
for i in range(len(assigns)): assigns[i] = assigns[i].substitute(becomes_map)
for buf in {u for t in (self,)+lst for u in t.uop.toposort() if u.op is Ops.BUFFER}:
if buf in _pending_assigns: _realize_pending(buf)
for target_uop in {u for t in (self,)+lst for u in t.uop.toposort() if u in _pending_assigns}: _realize_pending(target_uop)
if len(to_realize:=[x for x in (self,)+lst if not x.uop.has_buffer_identity()]):
run_schedule(*Tensor.schedule_with_vars(*to_realize), do_update_stats=do_update_stats)
return self
Expand All @@ -307,20 +318,22 @@ def assign(self, x:Tensor|PyConst|list|tuple) -> Tensor:
# broadcast x (shape only, dtype must match)
if self.shape != x.shape: x = x._broadcast_to(self.shape)
if self.shape != x.shape: raise RuntimeError(f"assign shape mismatch {self.shape} != {x.shape}")
if not is_disk and self.device != x.device: raise RuntimeError(f"assign device mismatch {self.device} != {x.device}")
if self.device != x.device:
if is_disk and (self.uop is self.uop.base or x.uop.op in {Ops.BITCAST, Ops.CAST}):
x = x.contiguous().to(self.device)
elif not is_disk: raise RuntimeError(f"assign device mismatch {self.device} != {x.device}")
if self.dtype != x.dtype: raise RuntimeError(f"assign dtype mismatch {self.dtype} != {x.dtype}")
if isinstance(self.device, tuple) and self.uop.axis != x.uop.axis: raise RuntimeError(f"multi axis mismatch {self.uop.axis} != {x.uop.axis}")

# TODO: this is a hack for writing to DISK. remove with working assign
if is_disk:
self._buffer().copyin(x._data())
return self
result = self._apply_uop(UOp.assign, x)
# track view assigns (not full-buffer or assign-chain) so they can be side-realized when the buffer is read
if (buf_uop:=self.uop.base).op is Ops.BUFFER and self.uop.op is not Ops.ASSIGN and not self.uop.has_buffer_identity():
target_uop = self.uop.base
if is_disk:
while target_uop.op is Ops.BITCAST: target_uop = target_uop.src[0].base
if self.uop.op is not Ops.ASSIGN and not self.uop.has_buffer_identity() and self.uop is not target_uop and target_uop.op in {Ops.BUFFER, Ops.COPY}:
# deduplicate: if the value is already a pending assign for this buffer (e.g. __iadd__ in __setitem__), remove it
if x.uop in _pending_assigns.get(buf_uop, []): _pending_assigns[buf_uop].remove(x.uop)
_pending_assigns.setdefault(buf_uop, []).append(result.uop)
if x.uop in _pending_assigns.get(target_uop, []): _pending_assigns[target_uop].remove(x.uop)
_pending_assigns.setdefault(target_uop, []).append(result.uop)
return self.replace(result)

def detach(self) -> Tensor:
Expand Down Expand Up @@ -1344,6 +1357,8 @@ def __setitem__(self, indices, v:Tensor|PyConst|list|tuple) -> None:
if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
self.assign(self._getitem(indices, v))
elif is_disk or self.uop.is_realized: # basic setitem, self is realized. TODO: disk uop.base is a COPY and not realized
if is_disk and any(isinstance(i, slice) and i.step not in (None, 1) for i in idx):
raise RuntimeError("strided setitem is not supported for DISK tensors")
self[indices].assign(v)
else: # basic setitem, self is not realized
if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
Expand Down
Loading