diff --git a/KLR/Trace/Lang.lean b/KLR/Trace/Lang.lean index 62f310f7..3f51642f 100644 --- a/KLR/Trace/Lang.lean +++ b/KLR/Trace/Lang.lean @@ -45,6 +45,27 @@ nki builtin.lang.ndarray Trace.addSharedBuffer tensor return .access (.simple tensor) +nki builtin.lang.zeros + (shape : Shape) + (dtype : Dtype) + (buffer : Option Memory := none) + (name : Option String := none) := do + let memory := buffer.getD .sbuf + let (parSize, freeSize) := Address.defaultSize shape dtype + let name <- tensorName name + let address := { name, memory, parSize, freeSize, parOffset := none, freeOffset := none : Address } + let tensor <- TensorName.make name dtype shape (some address) (<- flags.address_rotation) + if buffer == some .shared_hbm || buffer == some .hbm then + Trace.addSharedBuffer tensor + let access := Access.simple tensor + Trace.add_stmt $ .oper (.memSet { + dst := .abstract access, + value := if dtype.isInt then .int 0 else .float 0.0, + dtype := access.tensor.dtype, + engine := Engine.unassigned + }) none + return .access access + nki builtin.lang.par_dim (t : Term) := do warn "par_dim is deprecated" return t diff --git a/interop/test/test_basic.py b/interop/test/test_basic.py index f8be1eca..2164c1b9 100644 --- a/interop/test/test_basic.py +++ b/interop/test/test_basic.py @@ -145,6 +145,22 @@ def min_max_test(t): assert 2.0 == max([2.0, 1]) assert 2.0 == max([1, 2.0]) +def zeros_default_buffer(t): + z = nl.zeros((128, 512), dtype=nl.float16) + nisa.dma_copy(t[0:128, 0:512], z[:, :]) + +def zeros_sbuf(t): + z = nl.zeros((128, 512), dtype=nl.float16, buffer=nl.sbuf) + nisa.dma_copy(t[0:128, 0:512], z[:, :]) + +def zeros_psum(t): + z = nl.zeros((128, 512), dtype=nl.float32, buffer=nl.psum) + nisa.dma_copy(t[0:128, 0:512], z[:, :]) + +def zeros_int(t): + z = nl.zeros((128, 512), dtype=nl.int8, buffer=nl.sbuf) + nisa.dma_copy(t[0:128, 0:512], z[:, :]) + # test each function in turn @pytest.mark.parametrize("f", [ const_stmt, @@ -158,7 +174,11 @@ def min_max_test(t): ifs, loops, undefined_ok, - min_max_test + min_max_test, + zeros_default_buffer, + zeros_sbuf, + zeros_psum, + zeros_int, ]) def test_succeed(f): t = np.zeros((10,10,10), dtype=np.float32)