diff --git a/tinygrad/engine/allocations.py b/tinygrad/engine/allocations.py index 1c42d056cf55d..20f20f4f328ff 100644 --- a/tinygrad/engine/allocations.py +++ b/tinygrad/engine/allocations.py @@ -88,15 +88,14 @@ def contiguous_mops_to_view(c:UOp): if not hasattr(Device[c.device].allocator, "_offset"): return None # see if this can be a view - size_offset = src.contiguous_view_offset() - if size_offset is None: return None + offset = src.contiguous_view_offset() + if offset is None: return None # merge BUFFER_VIEWs - size, offset = size_offset if buf.op is Ops.BUFFER_VIEW: offset, buf = offset + buf.arg[1], buf.src[0] # NOTE: this contiguous is removed because this BUFFER_VIEW/RESHAPE has_buffer_identity - return UOp(Ops.BUFFER_VIEW, src.dtype, (buf,), (size, offset)).reshape(src.shape).contiguous(tag=c.tag) + return UOp(Ops.BUFFER_VIEW, src.dtype, (buf,), (src.size, offset)).reshape(src.shape).contiguous(tag=c.tag) def transform_precompiled_call(c:UOp) -> UOp|None: diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index a12e68994c82f..e95d3e115756f 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -657,19 +657,19 @@ def buf_uop(self) -> UOp: while len(s.src) and s.op not in {Ops.BUFFER, Ops.PARAM, Ops.BUFFERIZE, Ops.MSTACK}: s = s.src[0] return s - def contiguous_view_offset(self) -> tuple[int, int]|None: - """If movement ops on a BUFFER collapse to a contiguous range, return (size, offset) in elements. Otherwise None.""" + def contiguous_view_offset(self) -> int|None: + """If movement ops on a BUFFER collapse to a contiguous range, return `offset` in elements. Otherwise None.""" from tinygrad.schedule.rangeify import pm_mops from tinygrad.uop.symbolic import symbolic out = graph_rewrite(self._mop(Ops.RESHAPE, (self.size,)).index(UOp.range(self.size, 0)), pm_mops+symbolic, name="contiguous_view_offset") if out.op is not Ops.INDEX: return None if out.src[1].op is Ops.CONST and self.size == 1: if not isinstance(out.src[1].arg, int): return None # masked/padded regions produce InvalidType - return (1, out.src[1].arg) - if out.src[1].op is Ops.RANGE: return (self.size, 0) + return out.src[1].arg + if out.src[1].op is Ops.RANGE: return 0 if out.src[1].op is Ops.ADD and out.src[1].src[0].op is Ops.RANGE and out.src[1].src[1].op is Ops.CONST: if not isinstance(out.src[1].src[1].arg, int): return None # masked/padded regions produce InvalidType - return (self.size, out.src[1].src[1].arg) + return out.src[1].src[1].arg return None def has_buffer_identity(self): @@ -683,12 +683,11 @@ def buffer(self) -> Buffer|MultiBuffer: if self.op in {Ops.CONTIGUOUS, Ops.RESHAPE, Ops.DETACH, Ops.AFTER}: return self.src[0].buffer # this buffer can process disk tensors and simple movement ops if self is not self.base: - size_offset = self.contiguous_view_offset() - if size_offset is None: raise RuntimeError(f"cannot collapse movement ops on {self.base.op} to a contiguous view") - size, offset = size_offset + offset = self.contiguous_view_offset() + if offset is None: raise RuntimeError(f"cannot collapse movement ops on {self.base.op} to a contiguous view") buf = self.base.buffer assert isinstance(buf, Buffer), "must be a Buffer for movement ops" - return buf.view(size, self.dtype, offset*self.dtype.itemsize) + return buf.view(self.size, self.dtype, offset*self.dtype.itemsize) if self.op is Ops.BITCAST: buf = self.src[0].buffer assert isinstance(buf, Buffer), "must be a Buffer for BITCAST"