
Commit 171ff77

add workload obj, execution and mlir utils, and two workload examples
1 parent ba8aa78 commit 171ff77

6 files changed, +749 -0 lines changed

Lines changed: 178 additions & 0 deletions
@@ -0,0 +1,178 @@
"""
Workload example: Element-wise sum of two (M, N) float32 arrays on CPU.
"""
import numpy as np
from mlir import ir
from mlir.runtime.np_to_memref import get_ranked_memref_descriptor
from mlir.dialects import func, linalg, bufferization
from mlir.dialects import transform
from functools import cached_property
from lighthouse import Workload
from lighthouse.utils.mlir import (
    apply_registered_pass,
    canonicalize,
    cse,
    match,
)
from lighthouse.utils.execution import (
    lower_payload,
    execute,
    benchmark,
)

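# The methods below implement the Workload interface consumed by the
# lighthouse execution helpers: get_input_arrays() and verify() wrap a run,
# get_complexity() reports FLOP and byte counts, and payload_module() /
# schedule_module() provide the MLIR payload and its transform-dialect
# lowering schedule (this summary is inferred from how lower_payload,
# execute, and benchmark are used in the __main__ block below).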
class ElementwiseSum(Workload):
    """
    Computes element-wise sum of (M, N) float32 arrays on CPU.

    We can construct the input arrays and compute the reference solution in
    Python with Numpy.
    """

    def __init__(self, M, N):
        self.M = M
        self.N = N
        self.dtype = np.float32
        self.context = ir.Context()
        self.location = ir.Location.unknown(context=self.context)

    @cached_property
    def _input_arrays(self):
        print(" * Generating input arrays...")
        np.random.seed(2)
        A = np.random.rand(self.M, self.N).astype(self.dtype)
        B = np.random.rand(self.M, self.N).astype(self.dtype)
        C = np.zeros((self.M, self.N), dtype=self.dtype)
        return [A, B, C]

    @cached_property
    def _reference_solution(self):
        print(" * Computing reference solution...")
        A, B, _ = self._input_arrays
        return A + B

    def get_input_arrays(self, execution_engine):
        return [
            get_ranked_memref_descriptor(a) for a in self._input_arrays
        ]

    def verify(self, execution_engine, verbose: int = 0) -> bool:
        C = self._input_arrays[2]
        C_ref = self._reference_solution
        if verbose > 1:
            print("Reference solution:")
            print(C_ref)
            print("Computed solution:")
            print(C)
        success = np.allclose(C, C_ref)
        if verbose:
            if success:
                print("PASSED")
            else:
                print("FAILED Result mismatch!")
        return success

    def requirements(self):
        return []

    def get_complexity(self):
        nbytes = np.dtype(self.dtype).itemsize
        flop_count = self.M * self.N  # one addition per element
        memory_reads = 2 * self.M * self.N * nbytes  # read A and B
        memory_writes = self.M * self.N * nbytes  # write C
        return (flop_count, memory_reads, memory_writes)

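    # payload_module() builds the computation itself: a func.func taking three
    # (M, N) memrefs, which wraps A and B as tensors, adds them with
    # linalg.add, and materializes the result into the C buffer. The
    # llvm.emit_c_interface attribute exposes a C-compatible wrapper so the
    # function can be invoked from Python through the MLIR ExecutionEngine.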
    def payload_module(self):
        with self.context, self.location:
            float32_t = ir.F32Type.get()
            shape = (self.M, self.N)
            tensor_t = ir.RankedTensorType.get(shape, float32_t)
            memref_t = ir.MemRefType.get(shape, float32_t)
            mod = ir.Module.create()
            with ir.InsertionPoint(mod.body):
                args = [memref_t, memref_t, memref_t]
                f = func.FuncOp(self.payload_function_name, (tuple(args), ()))
                f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
                with ir.InsertionPoint(f.add_entry_block()):
                    A = f.arguments[0]
                    B = f.arguments[1]
                    C = f.arguments[2]
                    a_tensor = bufferization.ToTensorOp(tensor_t, A, restrict=True)
                    b_tensor = bufferization.ToTensorOp(tensor_t, B, restrict=True)
                    c_tensor = bufferization.ToTensorOp(
                        tensor_t, C, restrict=True, writable=True
                    )
                    add = linalg.add(a_tensor, b_tensor, outs=[c_tensor])
                    bufferization.MaterializeInDestinationOp(
                        None, add, C, restrict=True, writable=True
                    )
                    func.ReturnOp(())
        return mod

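    # schedule_module() expresses the lowering as a transform-dialect script:
    # the __transform_main named sequence matches the payload func.func, walks
    # up to the enclosing builtin.module, and applies bufferization and
    # lowering passes to it. With dump_kernel == "bufferized" the schedule
    # stops after bufferization and loop conversion so the intermediate IR can
    # be inspected; otherwise it lowers all the way to the LLVM dialect.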
    def schedule_module(self, dump_kernel=None, parameters=None):
        with self.context, self.location:
            schedule_module = ir.Module.create()
            schedule_module.operation.attributes[
                "transform.with_named_sequence"] = ir.UnitAttr.get()
            with ir.InsertionPoint(schedule_module.body):
                named_sequence = transform.NamedSequenceOp(
                    "__transform_main",
                    [transform.AnyOpType.get()],
                    [],
                    arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}],
                )
                with ir.InsertionPoint(named_sequence.body):
                    anytype = transform.AnyOpType.get()
                    # note: renamed to func_op to avoid shadowing the imported
                    # func dialect module
                    func_op = match(named_sequence.bodyTarget, ops={"func.func"})
                    mod = transform.get_parent_op(
                        anytype,
                        func_op,
                        op_name="builtin.module",
                        deduplicate=True,
                    )
                    mod = apply_registered_pass(mod, "one-shot-bufferize")
                    mod = apply_registered_pass(mod, "convert-linalg-to-loops")
                    cse(mod)
                    canonicalize(mod)

                    if dump_kernel == "bufferized":
                        transform.YieldOp()
                        return schedule_module

                    mod = apply_registered_pass(mod, "convert-scf-to-cf")
                    mod = apply_registered_pass(mod, "finalize-memref-to-llvm")
                    mod = apply_registered_pass(mod, "convert-cf-to-llvm")
                    mod = apply_registered_pass(mod, "convert-arith-to-llvm")
                    mod = apply_registered_pass(mod, "convert-func-to-llvm")
                    mod = apply_registered_pass(mod,
                                                "reconcile-unrealized-casts")
                    transform.YieldOp()

        return schedule_module

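# Running this file directly dumps the bufferized kernel, executes the
# workload twice (the cached input arrays and reference solution are reused
# on the second run), then benchmarks it and reports timing statistics and
# arithmetic throughput.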
if __name__ == "__main__":
    wload = ElementwiseSum(400, 400)

    print(" Dump kernel ".center(60, "-"))
    lower_payload(wload, dump_kernel="bufferized", dump_schedule=True)

    print(" Execute 1 ".center(60, "-"))
    execute(wload, verbose=2)

    print(" Execute 2 ".center(60, "-"))
    execute(wload, verbose=1)

    print(" Benchmark ".center(60, "-"))
    times = benchmark(wload)
    times *= 1e6  # convert to microseconds
    # compute statistics
    mean = np.mean(times)
    min = np.min(times)
    max = np.max(times)
    std = np.std(times)
    print(f"Timings (us): "
          f"mean={mean:.2f}+/-{std:.2f} min={min:.2f} max={max:.2f}")
    flop_count = wload.get_complexity()[0]
    gflops = flop_count / (mean * 1e-6) / 1e9
    print(f"Throughput: {gflops:.2f} GFLOPS")

Lines changed: 217 additions & 0 deletions
@@ -0,0 +1,217 @@
"""
Workload example: Element-wise sum of two (M, N) float32 arrays on CPU.

In this example, allocation and deallocation of input arrays is done in MLIR.
"""
import numpy as np
from mlir import ir
from mlir.runtime.np_to_memref import (
    ranked_memref_to_numpy,
    make_nd_memref_descriptor,
    as_ctype,
)
from mlir.dialects import func, linalg, bufferization, arith, memref
from mlir.dialects import transform
import ctypes
from contextlib import contextmanager
from lighthouse import Workload
from lighthouse.utils.mlir import (
    apply_registered_pass,
    canonicalize,
    cse,
    match,
)
from lighthouse.utils import get_packed_arg
from lighthouse.utils.execution import (
    lower_payload,
    execute,
    benchmark,
)
from example import ElementwiseSum

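# The emit_* helpers below append small host-side functions to an existing
# MLIR module: buffer allocation and deallocation plus constant and random
# fills. Each function carries llvm.emit_c_interface so that, after lowering,
# it can be looked up and called from Python via the ExecutionEngine.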
def emit_host_alloc(mod, suffix, element_type, rank=2):
    dyn = ir.ShapedType.get_dynamic_size()
    memref_dyn_t = ir.MemRefType.get(rank*(dyn,), element_type)
    index_t = ir.IndexType.get()
    i32_t = ir.IntegerType.get_signless(32)
    with ir.InsertionPoint(mod.body):
        f = func.FuncOp(
            "host_alloc_" + suffix, (rank*(i32_t,), (memref_dyn_t,))
        )
        f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
        with ir.InsertionPoint(f.add_entry_block()):
            dims = [
                arith.IndexCastOp(index_t, a) for a in list(f.arguments)
            ]
            alloc = memref.alloc(memref_dyn_t, dims, [])
            func.ReturnOp((alloc,))


def emit_host_dealloc(mod, suffix, element_type, rank=2):
    dyn = ir.ShapedType.get_dynamic_size()
    memref_dyn_t = ir.MemRefType.get(rank*(dyn,), element_type)
    with ir.InsertionPoint(mod.body):
        f = func.FuncOp("host_dealloc_" + suffix, ((memref_dyn_t,), ()))
        f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
        with ir.InsertionPoint(f.add_entry_block()):
            memref.dealloc(f.arguments[0])
            func.ReturnOp(())


def emit_fill_constant(mod, suffix, value, element_type, rank=2):
    dyn = ir.ShapedType.get_dynamic_size()
    memref_dyn_t = ir.MemRefType.get(rank*(dyn,), element_type)
    with ir.InsertionPoint(mod.body):
        f = func.FuncOp("host_fill_constant_" + suffix, ((memref_dyn_t,), ()))
        f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
        with ir.InsertionPoint(f.add_entry_block()):
            const = arith.constant(element_type, value)
            linalg.fill(const, outs=[f.arguments[0]])
            func.ReturnOp(())


def emit_fill_random(mod, suffix, element_type, min=0.0, max=1.0, seed=2):
    rank = 2
    dyn = ir.ShapedType.get_dynamic_size()
    memref_dyn_t = ir.MemRefType.get(rank*(dyn,), element_type)
    i32_t = ir.IntegerType.get_signless(32)
    f64_t = ir.F64Type.get()
    with ir.InsertionPoint(mod.body):
        f = func.FuncOp("host_fill_random_" + suffix, ((memref_dyn_t,), ()))
        f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
        with ir.InsertionPoint(f.add_entry_block()):
            min_cst = arith.constant(f64_t, min)
            max_cst = arith.constant(f64_t, max)
            seed_cst = arith.constant(i32_t, seed)
            linalg.fill_rng_2d(min_cst, max_cst, seed_cst, outs=[f.arguments[0]])
            func.ReturnOp(())

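# ElementwiseSumMLIRAlloc reuses the payload and schedule of ElementwiseSum
# but manages the input buffers through the MLIR-emitted host functions
# instead of NumPy: allocate() is a context manager that owns the memrefs,
# get_input_arrays() fills them via the emitted fill functions, and verify()
# reads them back as NumPy arrays to recompute the reference solution.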
class ElementwiseSumMLIRAlloc(ElementwiseSum):
    """
    Computes element-wise sum of (M, N) float32 arrays on CPU.

    Extends ElementwiseSum by allocating input arrays in MLIR.
    """

    def __init__(self, M, N):
        super().__init__(M, N)
        # keep track of allocated memrefs
        self.memrefs = {}

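    # Calling convention note (an inference from how these helpers are used
    # here): each emitted function is fetched with execution_engine.lookup()
    # and called with a packed argument list built by get_packed_arg, a
    # lighthouse utility that appears to pack pointers to the individual
    # arguments into the single packed-args array expected by the
    # llvm.emit_c_interface wrappers. Memrefs are passed as pointers to
    # ranked memref descriptors created with make_nd_memref_descriptor.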
    def _allocate_array(self, name, execution_engine):
        if name in self.memrefs:
            return self.memrefs[name]
        alloc_func = execution_engine.lookup("host_alloc_f32")
        shape = (self.M, self.N)
        mref = make_nd_memref_descriptor(len(shape), as_ctype(self.dtype))()
        ptr_mref = ctypes.pointer(ctypes.pointer(mref))
        ptr_dims = [ctypes.pointer(ctypes.c_int32(d)) for d in shape]
        alloc_func(get_packed_arg([ptr_mref, *ptr_dims]))
        self.memrefs[name] = mref
        return mref

    def _allocate_inputs(self, execution_engine):
        self._allocate_array("A", execution_engine)
        self._allocate_array("B", execution_engine)
        self._allocate_array("C", execution_engine)

    def _deallocate_all(self, execution_engine):
        for mref in self.memrefs.values():
            dealloc_func = execution_engine.lookup("host_dealloc_f32")
            ptr_mref = ctypes.pointer(ctypes.pointer(mref))
            dealloc_func(get_packed_arg([ptr_mref]))
        self.memrefs = {}

    @contextmanager
    def allocate(self, execution_engine):
        try:
            self._allocate_inputs(execution_engine)
            yield None
        finally:
            self._deallocate_all(execution_engine)

    def get_input_arrays(self, execution_engine):
        A = self._allocate_array("A", execution_engine)
        B = self._allocate_array("B", execution_engine)
        C = self._allocate_array("C", execution_engine)

        # initialize with MLIR
        fill_zero_func = execution_engine.lookup("host_fill_constant_zero_f32")
        fill_random_func = execution_engine.lookup("host_fill_random_f32")
        fill_zero_func(get_packed_arg([ctypes.pointer(ctypes.pointer(C))]))
        fill_random_func(get_packed_arg([ctypes.pointer(ctypes.pointer(A))]))
        fill_random_func(get_packed_arg([ctypes.pointer(ctypes.pointer(B))]))

        return [A, B, C]

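    # verify() does not reuse the NumPy reference of the parent class: A and B
    # are filled by linalg.fill_rng_2d, whose values need not match NumPy's
    # RNG. Instead, the MLIR-filled buffers are viewed as NumPy arrays and the
    # reference C_ref = A + B is recomputed from them.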
    def verify(self, execution_engine, verbose: int = 0) -> bool:
        # compute reference solution with numpy
        A = ranked_memref_to_numpy([self.memrefs["A"]])
        B = ranked_memref_to_numpy([self.memrefs["B"]])
        C = ranked_memref_to_numpy([self.memrefs["C"]])
        C_ref = A + B
        if verbose > 1:
            print("Reference solution:")
            print(C_ref)
            print("Computed solution:")
            print(C)
        success = np.allclose(C, C_ref)

        # Alternatively we could have done the verification in MLIR by emitting
        # a check function.
        # Here we just call the payload function again.
        # self._allocate_array("C_ref", execution_engine)
        # func = execution_engine.lookup("payload")
        # func(get_packed_arg([
        #     ctypes.pointer(ctypes.pointer(self.memrefs["A"])),
        #     ctypes.pointer(ctypes.pointer(self.memrefs["B"])),
        #     ctypes.pointer(ctypes.pointer(self.memrefs["C_ref"])),
        # ]))
        # Check correctness with numpy.
        # C = ranked_memref_to_numpy([self.memrefs["C"]])
        # C_ref = ranked_memref_to_numpy([self.memrefs["C_ref"]])
        # success = np.allclose(C, C_ref)

        if verbose:
            if success:
                print("PASSED")
            else:
                print("FAILED Result mismatch!")
        return success

    def payload_module(self):
        mod = super().payload_module()
        # extend the payload module with de/alloc/fill functions
        with self.context, self.location:
            float32_t = ir.F32Type.get()
            emit_host_alloc(mod, "f32", float32_t)
            emit_host_dealloc(mod, "f32", float32_t)
            emit_fill_constant(mod, "zero_f32", 0.0, float32_t)
            emit_fill_random(mod, "f32", float32_t, min=-1.0, max=1.0)
        return mod

if __name__ == "__main__":
    wload = ElementwiseSumMLIRAlloc(400, 400)

    print(" Dump kernel ".center(60, "-"))
    lower_payload(wload, dump_kernel="bufferized", dump_schedule=False)

    print(" Execute ".center(60, "-"))
    execute(wload, verbose=2)

    print(" Benchmark ".center(60, "-"))
    times = benchmark(wload)
    times *= 1e6  # convert to microseconds
    # compute statistics
    mean = np.mean(times)
    min = np.min(times)
    max = np.max(times)
    std = np.std(times)
    print(f"Timings (us): "
          f"mean={mean:.2f}+/-{std:.2f} min={min:.2f} max={max:.2f}")
    flop_count = wload.get_complexity()[0]
    gflops = flop_count / (mean * 1e-6) / 1e9
    print(f"Throughput: {gflops:.2f} GFLOPS")
