Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions KLR/Trace/Lang.lean
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,27 @@ nki builtin.lang.ndarray
Trace.addSharedBuffer tensor
return .access (.simple tensor)

-- `zeros`: create a tensor of the given shape/dtype and record a `memSet`
-- operator statement that fills it with zero.
--   shape  : shape of the new tensor
--   dtype  : element type; also selects an integer or float zero literal
--   buffer : optional memory; `.sbuf` is used when omitted
--   name   : optional name, resolved through `tensorName`
-- Returns a simple access to the new tensor.
nki builtin.lang.zeros
    (shape : Shape)
    (dtype : Dtype)
    (buffer : Option Memory := none)
    (name : Option String := none) := do
  let (parSize, freeSize) := Address.defaultSize shape dtype
  let tname <- tensorName name
  let address : Address := {
    name := tname,
    memory := buffer.getD .sbuf,
    parSize,
    freeSize,
    parOffset := none,
    freeOffset := none
  }
  let rotation <- flags.address_rotation
  let tensor <- TensorName.make tname dtype shape (some address) rotation
  -- Only an explicitly requested HBM placement is registered as a shared buffer.
  let inHbm := buffer == some .hbm || buffer == some .shared_hbm
  if inHbm then
    Trace.addSharedBuffer tensor
  let view := Access.simple tensor
  Trace.add_stmt <| .oper (.memSet {
    dst := .abstract view,
    value := if dtype.isInt then .int 0 else .float 0.0,
    dtype := view.tensor.dtype,
    engine := Engine.unassigned
  }) none
  return .access view

-- `par_dim` is deprecated: it emits a warning and hands back its argument unchanged.
nki builtin.lang.par_dim (t : Term) := do
  warn "par_dim is deprecated"
  pure t
Expand Down
22 changes: 21 additions & 1 deletion interop/test/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,22 @@ def min_max_test(t):
assert 2.0 == max([2.0, 1])
assert 2.0 == max([1, 2.0])

def zeros_default_buffer(t):
    # nl.zeros called without a buffer argument; the result is handed to
    # nisa.dma_copy together with a slice of t.
    zero_tile = nl.zeros((128, 512), dtype=nl.float16)
    nisa.dma_copy(t[0:128, 0:512], zero_tile[:, :])

def zeros_sbuf(t):
    # nl.zeros with buffer=nl.sbuf given explicitly; the result is handed to
    # nisa.dma_copy together with a slice of t.
    zero_tile = nl.zeros((128, 512), dtype=nl.float16, buffer=nl.sbuf)
    nisa.dma_copy(t[0:128, 0:512], zero_tile[:, :])

def zeros_psum(t):
    # nl.zeros with a float32 dtype and buffer=nl.psum; the result is handed
    # to nisa.dma_copy together with a slice of t.
    zero_tile = nl.zeros((128, 512), dtype=nl.float32, buffer=nl.psum)
    nisa.dma_copy(t[0:128, 0:512], zero_tile[:, :])

def zeros_int(t):
    # nl.zeros with an integer dtype (nl.int8) in nl.sbuf; the result is
    # handed to nisa.dma_copy together with a slice of t.
    zero_tile = nl.zeros((128, 512), dtype=nl.int8, buffer=nl.sbuf)
    nisa.dma_copy(t[0:128, 0:512], zero_tile[:, :])

# test each function in turn
@pytest.mark.parametrize("f", [
const_stmt,
Expand All @@ -158,7 +174,11 @@ def min_max_test(t):
ifs,
loops,
undefined_ok,
min_max_test
min_max_test,
zeros_default_buffer,
zeros_sbuf,
zeros_psum,
zeros_int,
])
def test_succeed(f):
t = np.zeros((10,10,10), dtype=np.float32)
Expand Down
Loading