From f53c44ddda83f45d2868ff347da93c96321fb1cf Mon Sep 17 00:00:00 2001
From: Catherine Marks
Date: Tue, 3 Mar 2026 12:57:04 -0800
Subject: [PATCH] fix(lang): Add nl.zeros with buffer parameter support

nl.zeros was previously mapped to builtin.lang.zeros which was moved to
builtin.tensor.zeros (a trace-time constant) in NKI-499, breaking the
nl.zeros -> builtin.lang.zeros mapping. Customers using nl.zeros with
buffer= (e.g. buffer=nl.sbuf) received 'unexpected keyword argument
buffer' errors.

Add builtin.lang.zeros to Lang.lean that allocates an on-device tensor
(matching ndarray semantics) and emits a memset to zero-initialize it.
Supports buffer=nl.sbuf (default), nl.psum, nl.hbm, nl.shared_hbm, and
integer dtypes.

Workaround for customers on affected releases:
  ZS = nl.ndarray(shape, dtype=dtype, buffer=buffer)
  nisa.memset(ZS, value=0.0)

Fixes: V2117488695
---
 KLR/Trace/Lang.lean        | 21 +++++++++++++++++++++
 interop/test/test_basic.py | 22 +++++++++++++++++++++-
 2 files changed, 42 insertions(+), 1 deletion(-)

diff --git a/KLR/Trace/Lang.lean b/KLR/Trace/Lang.lean
index 62f310f7..3f51642f 100644
--- a/KLR/Trace/Lang.lean
+++ b/KLR/Trace/Lang.lean
@@ -45,6 +45,27 @@ nki builtin.lang.ndarray
     Trace.addSharedBuffer tensor
   return .access (.simple tensor)
 
+nki builtin.lang.zeros  -- nl.zeros builtin: build a tensor and emit a memSet that fills it with zero
+  (shape : Shape)
+  (dtype : Dtype)
+  (buffer : Option Memory := none)  -- requested memory; `none` falls back to .sbuf below
+  (name : Option String := none) := do
+  let memory := buffer.getD .sbuf  -- default memory is sbuf when no buffer is passed
+  let (parSize, freeSize) := Address.defaultSize shape dtype
+  let name <- tensorName name  -- NOTE(review): presumably supplies a generated name when none is given - confirm
+  let address := { name, memory, parSize, freeSize, parOffset := none, freeOffset := none : Address }  -- both offsets are left unset
+  let tensor <- TensorName.make name dtype shape (some address) (<- flags.address_rotation)
+  if buffer == some .shared_hbm || buffer == some .hbm then  -- only an explicit hbm / shared_hbm request takes this branch
+    Trace.addSharedBuffer tensor  -- hand the tensor to Trace.addSharedBuffer, as the ndarray builtin above does
+  let access := Access.simple tensor
+  Trace.add_stmt $ .oper (.memSet {  -- append the zero-fill operation to the trace
+    dst := .abstract access,
+    value := if dtype.isInt then .int 0 else .float 0.0,  -- int zero for integer dtypes, float zero otherwise
+    dtype := access.tensor.dtype,
+    engine := Engine.unassigned  -- no engine is chosen here
+  }) none
+  return .access access  -- the caller receives an access to the zero-filled tensor
+
 nki builtin.lang.par_dim (t : Term) := do
   warn "par_dim is deprecated"
   return t
diff --git a/interop/test/test_basic.py b/interop/test/test_basic.py
index f8be1eca..2164c1b9 100644
--- a/interop/test/test_basic.py
+++ b/interop/test/test_basic.py
@@ -145,6 +145,22 @@ def min_max_test(t):
   assert 2.0 == max([2.0, 1])
   assert 2.0 == max([1, 2.0])
 
+def zeros_default_buffer(t):  # nl.zeros with no buffer= argument
+  z = nl.zeros((128, 512), dtype=nl.float16)
+  nisa.dma_copy(t[0:128, 0:512], z[:, :])  # pass the zeros result to a dma_copy alongside the kernel argument
+
+def zeros_sbuf(t):  # nl.zeros with an explicit buffer=nl.sbuf (the keyword that used to be rejected)
+  z = nl.zeros((128, 512), dtype=nl.float16, buffer=nl.sbuf)
+  nisa.dma_copy(t[0:128, 0:512], z[:, :])
+
+def zeros_psum(t):  # nl.zeros with buffer=nl.psum and a float32 dtype
+  z = nl.zeros((128, 512), dtype=nl.float32, buffer=nl.psum)
+  nisa.dma_copy(t[0:128, 0:512], z[:, :])
+
+def zeros_int(t):  # nl.zeros with an integer dtype (int8)
+  z = nl.zeros((128, 512), dtype=nl.int8, buffer=nl.sbuf)
+  nisa.dma_copy(t[0:128, 0:512], z[:, :])
+
 # test each function in turn
 @pytest.mark.parametrize("f", [
   const_stmt,
@@ -158,7 +174,11 @@ def min_max_test(t):
   ifs,
   loops,
   undefined_ok,
-  min_max_test
+  min_max_test,
+  zeros_default_buffer,  # new nl.zeros cases added to the test_succeed parameter list
+  zeros_sbuf,
+  zeros_psum,
+  zeros_int,
   ])
 def test_succeed(f):
   t = np.zeros((10,10,10), dtype=np.float32)