Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions tinygrad/engine/allocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,14 @@ def contiguous_mops_to_view(c:UOp):
if not hasattr(Device[c.device].allocator, "_offset"): return None

# see if this can be a view
size_offset = src.contiguous_view_offset()
if size_offset is None: return None
offset = src.contiguous_view_offset()
if offset is None: return None

# merge BUFFER_VIEWs
size, offset = size_offset
if buf.op is Ops.BUFFER_VIEW: offset, buf = offset + buf.arg[1], buf.src[0]

# NOTE: this contiguous is removed because this BUFFER_VIEW/RESHAPE has_buffer_identity
return UOp(Ops.BUFFER_VIEW, src.dtype, (buf,), (size, offset)).reshape(src.shape).contiguous(tag=c.tag)
return UOp(Ops.BUFFER_VIEW, src.dtype, (buf,), (src.size, offset)).reshape(src.shape).contiguous(tag=c.tag)


def transform_precompiled_call(c:UOp) -> UOp|None:
Expand Down
17 changes: 8 additions & 9 deletions tinygrad/uop/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,19 +657,19 @@ def buf_uop(self) -> UOp:
while len(s.src) and s.op not in {Ops.BUFFER, Ops.PARAM, Ops.BUFFERIZE, Ops.MSTACK}: s = s.src[0]
return s

def contiguous_view_offset(self) -> tuple[int, int]|None:
"""If movement ops on a BUFFER collapse to a contiguous range, return (size, offset) in elements. Otherwise None."""
def contiguous_view_offset(self) -> int|None:
"""If movement ops on a BUFFER collapse to a contiguous range, return `offset` in elements. Otherwise None."""
from tinygrad.schedule.rangeify import pm_mops
from tinygrad.uop.symbolic import symbolic
out = graph_rewrite(self._mop(Ops.RESHAPE, (self.size,)).index(UOp.range(self.size, 0)), pm_mops+symbolic, name="contiguous_view_offset")
if out.op is not Ops.INDEX: return None
if out.src[1].op is Ops.CONST and self.size == 1:
if not isinstance(out.src[1].arg, int): return None # masked/padded regions produce InvalidType
return (1, out.src[1].arg)
if out.src[1].op is Ops.RANGE: return (self.size, 0)
return out.src[1].arg
if out.src[1].op is Ops.RANGE: return 0
if out.src[1].op is Ops.ADD and out.src[1].src[0].op is Ops.RANGE and out.src[1].src[1].op is Ops.CONST:
if not isinstance(out.src[1].src[1].arg, int): return None # masked/padded regions produce InvalidType
return (self.size, out.src[1].src[1].arg)
return out.src[1].src[1].arg
return None

def has_buffer_identity(self):
Expand All @@ -683,12 +683,11 @@ def buffer(self) -> Buffer|MultiBuffer:
if self.op in {Ops.CONTIGUOUS, Ops.RESHAPE, Ops.DETACH, Ops.AFTER}: return self.src[0].buffer
# this buffer can process disk tensors and simple movement ops
if self is not self.base:
size_offset = self.contiguous_view_offset()
if size_offset is None: raise RuntimeError(f"cannot collapse movement ops on {self.base.op} to a contiguous view")
size, offset = size_offset
offset = self.contiguous_view_offset()
if offset is None: raise RuntimeError(f"cannot collapse movement ops on {self.base.op} to a contiguous view")
buf = self.base.buffer
assert isinstance(buf, Buffer), "must be a Buffer for movement ops"
return buf.view(size, self.dtype, offset*self.dtype.itemsize)
return buf.view(self.size, self.dtype, offset*self.dtype.itemsize)
if self.op is Ops.BITCAST:
buf = self.src[0].buffer
assert isinstance(buf, Buffer), "must be a Buffer for BITCAST"
Expand Down
Loading