Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions src/exo/core/LoopIR_pprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,20 @@
from .internal_cursors import Node, Gap, Block, Cursor, InvalidCursorError, GapType
from .prelude import *


def _format_code(code):
# See the following file for customization options:
# https://github.com/google/yapf/blob/main/yapf/yapflib/style.py
code, _ = FormatCode(
code,
style_config={
"based_on_style": "pep8",
"column_limit": 160,
},
)
return code.rstrip("\n")


# --------------------------------------------------------------------------- #
# --------------------------------------------------------------------------- #
# Notes on Layout Schemes...
Expand Down Expand Up @@ -110,10 +124,7 @@ def str(self):
assert len(self._lines) == 1
return self._lines[0]

fmtstr, linted = FormatCode("\n".join(self._lines))
if isinstance(self._node, LoopIR.proc):
assert linted, "generated unlinted code..."
return fmtstr
return _format_code("\n".join(self._lines))

def push(self, only=None):
if only is None:
Expand Down Expand Up @@ -321,10 +332,6 @@ def ptype(self, t):
# LoopIR Pretty Printing


def _format_code(code):
return FormatCode(code)[0].rstrip("\n")


@extclass(LoopIR.proc)
def __str__(self):
return _format_code("\n".join(_print_proc(self, PrintEnv(), "")))
Expand Down
3 changes: 1 addition & 2 deletions tests/golden/test_config/test_ld.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ def config_ld_i8(scale: f32 @ DRAM, src_stride: stride @ DRAM):
#
ConfigLoad.scale = scale
ConfigLoad.src_stride = src_stride
def ld_i8(n: size, m: size, scale: f32 @ DRAM, src: i8[n, m] @ DRAM,
dst: i8[n, 16] @ GEMM_SCRATCH):
def ld_i8(n: size, m: size, scale: f32 @ DRAM, src: i8[n, m] @ DRAM, dst: i8[n, 16] @ GEMM_SCRATCH):
assert n <= 16
assert m <= 16
assert stride(src, 1) == 1
Expand Down
3 changes: 1 addition & 2 deletions tests/golden/test_cursors/test_basic_forwarding2.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
def filter1D(ow: size, kw: size, x: f32[ow + kw - 1] @ DRAM, y: f32[ow] @ DRAM,
w: f32[kw] @ DRAM):
def filter1D(ow: size, kw: size, x: f32[ow + kw - 1] @ DRAM, y: f32[ow] @ DRAM, w: f32[kw] @ DRAM):
for outXo in seq(0, ow / 4):
sum: f32[4] @ DRAM # <-- NODE
for outXi in seq(0, 4):
Expand Down
3 changes: 1 addition & 2 deletions tests/golden/test_cursors/test_basic_forwarding3.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
def filter1D(ow: size, kw: size, x: f32[ow + kw - 1] @ DRAM, y: f32[ow] @ DRAM,
w: f32[kw] @ DRAM):
def filter1D(ow: size, kw: size, x: f32[ow + kw - 1] @ DRAM, y: f32[ow] @ DRAM, w: f32[kw] @ DRAM):
for outXo in seq(0, ow / 4):
sum: f32[4] @ DRAM # <-- NODE
for outXi in seq(0, 4):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,4 @@ def foo():
producer[1 + y, x] = 1.0
for x in seq(0, 10):
producer[1 + y, 1 + x] = 1.0
consumer[y, x] = producer[y, x] + producer[y, 1 + x] + producer[
1 + y, x] + producer[1 + y, 1 + x]
consumer[y, x] = producer[y, x] + producer[y, 1 + x] + producer[1 + y, x] + producer[1 + y, 1 + x]
17 changes: 5 additions & 12 deletions tests/golden/test_halide_ops/test_schedule_blur1d.txt
Original file line number Diff line number Diff line change
@@ -1,25 +1,18 @@
def blur1d_compute_at_store_root(n: size, consumer: i8[n] @ DRAM,
inp: i8[n + 6] @ DRAM):
def blur1d_compute_at_store_root(n: size, consumer: i8[n] @ DRAM, inp: i8[n + 6] @ DRAM):
producer: i8[1 + n] @ DRAM
for i in seq(0, n):
for ii in seq(0, 2):
producer[i + ii] = (inp[i + ii] + inp[1 + i + ii] +
inp[2 + i + ii] + inp[3 + i + ii] +
inp[4 + i + ii] + inp[5 + i + ii]) / 6.0
producer[i + ii] = (inp[i + ii] + inp[1 + i + ii] + inp[2 + i + ii] + inp[3 + i + ii] + inp[4 + i + ii] + inp[5 + i + ii]) / 6.0
consumer[i] = (producer[i] + producer[1 + i]) / 2.0

def blur1d_compute_at(n: size, consumer: i8[n] @ DRAM, inp: i8[n + 6] @ DRAM):
for i in seq(0, n):
producer: i8[2] @ DRAM
for ii in seq(0, 2):
producer[ii] = (inp[i + ii] + inp[1 + i + ii] + inp[2 + i + ii] +
inp[3 + i + ii] + inp[4 + i + ii] +
inp[5 + i + ii]) / 6.0
producer[ii] = (inp[i + ii] + inp[1 + i + ii] + inp[2 + i + ii] + inp[3 + i + ii] + inp[4 + i + ii] + inp[5 + i + ii]) / 6.0
consumer[i] = (producer[0] + producer[1]) / 2.0

def blur1d_inline(n: size, consumer: i8[n] @ DRAM, inp: i8[n + 6] @ DRAM):
for i in seq(0, n):
consumer[i] = ((inp[i] + inp[1 + i] + inp[2 + i] + inp[3 + i] +
inp[4 + i] + inp[5 + i]) / 6.0 +
(inp[1 + i] + inp[2 + i] + inp[3 + i] + inp[4 + i] +
inp[5 + i] + inp[6 + i]) / 6.0) / 2.0
consumer[i] = ((inp[i] + inp[1 + i] + inp[2 + i] + inp[3 + i] + inp[4 + i] + inp[5 + i]) / 6.0 +
(inp[1 + i] + inp[2 + i] + inp[3 + i] + inp[4 + i] + inp[5 + i] + inp[6 + i]) / 6.0) / 2.0
32 changes: 10 additions & 22 deletions tests/golden/test_halide_ops/test_schedule_blur2d.txt
Original file line number Diff line number Diff line change
@@ -1,57 +1,45 @@
def blur2d_compute_at_i_store_root(n: size, consumer: i8[n, n] @ DRAM,
sin: i8[n + 1, n + 1] @ DRAM):
def blur2d_compute_at_i_store_root(n: size, consumer: i8[n, n] @ DRAM, sin: i8[n + 1, n + 1] @ DRAM):
assert n % 4 == 0
producer: i8[1 + n, 1 + n] @ DRAM
for i in seq(0, n):
for ii in seq(0, 2):
for j in seq(0, 1 + n):
producer[i + ii, j] = sin[i + ii, j]
for j in seq(0, n):
consumer[i,
j] = (producer[i, j] + producer[i, 1 + j] +
producer[1 + i, j] + producer[1 + i, 1 + j]) / 4.0
consumer[i, j] = (producer[i, j] + producer[i, 1 + j] + producer[1 + i, j] + producer[1 + i, 1 + j]) / 4.0

def blur2d_compute_at_j_store_root(n: size, consumer: i8[n, n] @ DRAM,
sin: i8[n + 1, n + 1] @ DRAM):
def blur2d_compute_at_j_store_root(n: size, consumer: i8[n, n] @ DRAM, sin: i8[n + 1, n + 1] @ DRAM):
assert n % 4 == 0
producer: i8[1 + n, 1 + n] @ DRAM
for i in seq(0, n):
for j in seq(0, n):
for ji in seq(0, 2):
for ii in seq(0, 2):
producer[i + ii, j + ji] = sin[i + ii, j + ji]
consumer[i,
j] = (producer[i, j] + producer[i, 1 + j] +
producer[1 + i, j] + producer[1 + i, 1 + j]) / 4.0
consumer[i, j] = (producer[i, j] + producer[i, 1 + j] + producer[1 + i, j] + producer[1 + i, 1 + j]) / 4.0

def blur2d_compute_at_i(n: size, consumer: i8[n, n] @ DRAM,
sin: i8[n + 1, n + 1] @ DRAM):
def blur2d_compute_at_i(n: size, consumer: i8[n, n] @ DRAM, sin: i8[n + 1, n + 1] @ DRAM):
assert n % 4 == 0
for i in seq(0, n):
producer: i8[2, 1 + n] @ DRAM
for ii in seq(0, 2):
for j in seq(0, 1 + n):
producer[ii, j] = sin[i + ii, j]
for j in seq(0, n):
consumer[i, j] = (producer[0, j] + producer[0, 1 + j] +
producer[1, j] + producer[1, 1 + j]) / 4.0
consumer[i, j] = (producer[0, j] + producer[0, 1 + j] + producer[1, j] + producer[1, 1 + j]) / 4.0

def blur2d_compute_at_j_store_at_i(n: size, consumer: i8[n, n] @ DRAM,
sin: i8[n + 1, n + 1] @ DRAM):
def blur2d_compute_at_j_store_at_i(n: size, consumer: i8[n, n] @ DRAM, sin: i8[n + 1, n + 1] @ DRAM):
assert n % 4 == 0
for i in seq(0, n):
producer: i8[2, 1 + n] @ DRAM
for j in seq(0, n):
for ji in seq(0, 2):
for ii in seq(0, 2):
producer[ii, j + ji] = sin[i + ii, j + ji]
consumer[i, j] = (producer[0, j] + producer[0, 1 + j] +
producer[1, j] + producer[1, 1 + j]) / 4.0
consumer[i, j] = (producer[0, j] + producer[0, 1 + j] + producer[1, j] + producer[1, 1 + j]) / 4.0

def blur2d_inline(n: size, consumer: i8[n, n] @ DRAM,
sin: i8[n + 1, n + 1] @ DRAM):
def blur2d_inline(n: size, consumer: i8[n, n] @ DRAM, sin: i8[n + 1, n + 1] @ DRAM):
assert n % 4 == 0
for i in seq(0, n):
for j in seq(0, n):
consumer[i, j] = (sin[i + 0, j + 0] + sin[i + 0, j + 1] +
sin[i + 1, j + 0] + sin[i + 1, j + 1]) / 4.0
consumer[i, j] = (sin[i + 0, j + 0] + sin[i + 0, j + 1] + sin[i + 1, j + 0] + sin[i + 1, j + 1]) / 4.0
72 changes: 21 additions & 51 deletions tests/golden/test_halide_ops/test_schedule_tiled_blur2d.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
def blur2d_tiled(n: size, consumer: i8[n, n] @ DRAM,
sin: i8[n + 1, n + 1] @ DRAM):
def blur2d_tiled(n: size, consumer: i8[n, n] @ DRAM, sin: i8[n + 1, n + 1] @ DRAM):
assert n % 4 == 0
producer: i8[n + 1, n + 1] @ DRAM
for i in seq(0, n + 1):
Expand All @@ -9,15 +8,10 @@ def blur2d_tiled(n: size, consumer: i8[n, n] @ DRAM,
for j in seq(0, n / 4):
for ii in seq(0, 4):
for ji in seq(0, 4):
consumer[
4 * i + ii, 4 * j +
ji] = (producer[4 * i + ii, 4 * j + ji] +
producer[4 * i + ii, 4 * j + ji + 1] +
producer[4 * i + ii + 1, 4 * j + ji] +
producer[4 * i + ii + 1, 4 * j + ji + 1]) / 4.0
consumer[4 * i + ii, 4 * j + ji] = (producer[4 * i + ii, 4 * j + ji] + producer[4 * i + ii, 4 * j + ji + 1] +
producer[4 * i + ii + 1, 4 * j + ji] + producer[4 * i + ii + 1, 4 * j + ji + 1]) / 4.0

def blur2d_tiled_compute_at_i(n: size, consumer: i8[n, n] @ DRAM,
sin: i8[n + 1, n + 1] @ DRAM):
def blur2d_tiled_compute_at_i(n: size, consumer: i8[n, n] @ DRAM, sin: i8[n + 1, n + 1] @ DRAM):
assert n % 4 == 0
producer: i8[1 + n, 1 + n] @ DRAM
for i in seq(0, n / 4):
Expand All @@ -27,51 +21,36 @@ def blur2d_tiled_compute_at_i(n: size, consumer: i8[n, n] @ DRAM,
for j in seq(0, n / 4):
for ii in seq(0, 4):
for ji in seq(0, 4):
consumer[ii + 4 * i, ji + 4 * j] = (
producer[ii + 4 * i, ji + 4 * j] +
producer[ii + 4 * i, 1 + ji + 4 * j] +
producer[1 + ii + 4 * i, ji + 4 * j] +
producer[1 + ii + 4 * i, 1 + ji + 4 * j]) / 4.0
consumer[ii + 4 * i, ji + 4 * j] = (producer[ii + 4 * i, ji + 4 * j] + producer[ii + 4 * i, 1 + ji + 4 * j] +
producer[1 + ii + 4 * i, ji + 4 * j] + producer[1 + ii + 4 * i, 1 + ji + 4 * j]) / 4.0

def blur2d_tiled_compute_at_j(n: size, consumer: i8[n, n] @ DRAM,
sin: i8[n + 1, n + 1] @ DRAM):
def blur2d_tiled_compute_at_j(n: size, consumer: i8[n, n] @ DRAM, sin: i8[n + 1, n + 1] @ DRAM):
assert n % 4 == 0
producer: i8[1 + n, 1 + n] @ DRAM
for i in seq(0, n / 4):
for j in seq(0, n / 4):
for ji in seq(0, 5):
for ii in seq(0, 5):
producer[ii + 4 * i, ji + 4 * j] = sin[ii + 4 * i,
ji + 4 * j]
producer[ii + 4 * i, ji + 4 * j] = sin[ii + 4 * i, ji + 4 * j]
for ii in seq(0, 4):
for ji in seq(0, 4):
consumer[ii + 4 * i, ji + 4 * j] = (
producer[ii + 4 * i, ji + 4 * j] +
producer[ii + 4 * i, 1 + ji + 4 * j] +
producer[1 + ii + 4 * i, ji + 4 * j] +
producer[1 + ii + 4 * i, 1 + ji + 4 * j]) / 4.0
consumer[ii + 4 * i, ji + 4 * j] = (producer[ii + 4 * i, ji + 4 * j] + producer[ii + 4 * i, 1 + ji + 4 * j] +
producer[1 + ii + 4 * i, ji + 4 * j] + producer[1 + ii + 4 * i, 1 + ji + 4 * j]) / 4.0

def blur2d_tiled_compute_at_ii(n: size, consumer: i8[n, n] @ DRAM,
sin: i8[n + 1, n + 1] @ DRAM):
def blur2d_tiled_compute_at_ii(n: size, consumer: i8[n, n] @ DRAM, sin: i8[n + 1, n + 1] @ DRAM):
assert n % 4 == 0
producer: i8[1 + n, 1 + n] @ DRAM
for i in seq(0, n / 4):
for j in seq(0, n / 4):
for ii in seq(0, 4):
for iii in seq(0, 2):
for ji in seq(0, 5):
producer[ii + iii + 4 * i,
ji + 4 * j] = sin[ii + iii + 4 * i,
ji + 4 * j]
producer[ii + iii + 4 * i, ji + 4 * j] = sin[ii + iii + 4 * i, ji + 4 * j]
for ji in seq(0, 4):
consumer[ii + 4 * i, ji + 4 * j] = (
producer[ii + 4 * i, ji + 4 * j] +
producer[ii + 4 * i, 1 + ji + 4 * j] +
producer[1 + ii + 4 * i, ji + 4 * j] +
producer[1 + ii + 4 * i, 1 + ji + 4 * j]) / 4.0
consumer[ii + 4 * i, ji + 4 * j] = (producer[ii + 4 * i, ji + 4 * j] + producer[ii + 4 * i, 1 + ji + 4 * j] +
producer[1 + ii + 4 * i, ji + 4 * j] + producer[1 + ii + 4 * i, 1 + ji + 4 * j]) / 4.0

def blur2d_tiled_compute_at_ji(n: size, consumer: i8[n, n] @ DRAM,
sin: i8[n + 1, n + 1] @ DRAM):
def blur2d_tiled_compute_at_ji(n: size, consumer: i8[n, n] @ DRAM, sin: i8[n + 1, n + 1] @ DRAM):
assert n % 4 == 0
producer: i8[1 + n, 1 + n] @ DRAM
for i in seq(0, n / 4):
Expand All @@ -80,17 +59,11 @@ def blur2d_tiled_compute_at_ji(n: size, consumer: i8[n, n] @ DRAM,
for ji in seq(0, 4):
for jii in seq(0, 2):
for iii in seq(0, 2):
producer[ii + iii + 4 * i,
ji + jii + 4 * j] = sin[ii + iii + 4 * i,
ji + jii + 4 * j]
consumer[ii + 4 * i, ji + 4 * j] = (
producer[ii + 4 * i, ji + 4 * j] +
producer[ii + 4 * i, 1 + ji + 4 * j] +
producer[1 + ii + 4 * i, ji + 4 * j] +
producer[1 + ii + 4 * i, 1 + ji + 4 * j]) / 4.0
producer[ii + iii + 4 * i, ji + jii + 4 * j] = sin[ii + iii + 4 * i, ji + jii + 4 * j]
consumer[ii + 4 * i, ji + 4 * j] = (producer[ii + 4 * i, ji + 4 * j] + producer[ii + 4 * i, 1 + ji + 4 * j] +
producer[1 + ii + 4 * i, ji + 4 * j] + producer[1 + ii + 4 * i, 1 + ji + 4 * j]) / 4.0

def blur2d_tiled_compute_at_and_store_at_ji(n: size, consumer: i8[n, n] @ DRAM,
sin: i8[n + 1, n + 1] @ DRAM):
def blur2d_tiled_compute_at_and_store_at_ji(n: size, consumer: i8[n, n] @ DRAM, sin: i8[n + 1, n + 1] @ DRAM):
assert n % 4 == 0
for i in seq(0, n / 4):
for j in seq(0, n / 4):
Expand All @@ -99,8 +72,5 @@ def blur2d_tiled_compute_at_and_store_at_ji(n: size, consumer: i8[n, n] @ DRAM,
producer: i8[2, 2] @ DRAM
for jii in seq(0, 2):
for iii in seq(0, 2):
producer[iii, jii] = sin[ii + iii + 4 * i,
ji + jii + 4 * j]
consumer[ii + 4 * i, ji +
4 * j] = (producer[0, 0] + producer[0, 1] +
producer[1, 0] + producer[1, 1]) / 4.0
producer[iii, jii] = sin[ii + iii + 4 * i, ji + jii + 4 * j]
consumer[ii + 4 * i, ji + 4 * j] = (producer[0, 0] + producer[0, 1] + producer[1, 0] + producer[1, 1]) / 4.0
9 changes: 3 additions & 6 deletions tests/golden/test_im2col/test_im2col.txt
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
def im2col(C: size, W: size, R: size, x: R[C, W] @ DRAM,
y: R[C + 1, R + 1, W + 1] @ DRAM):
def im2col(C: size, W: size, R: size, x: R[C, W] @ DRAM, y: R[C + 1, R + 1, W + 1] @ DRAM):
for c in seq(0, C):
for r in seq(0, R):
for i in seq(0, W):
if 0 <= i - r:
y[c, r, i] = x[c, i - r]
def matmul(K: size, C: size, W: size, R: size, w: R[K, C, R] @ DRAM,
res: R[K, W] @ DRAM, y: R[C + 1, R + 1, W + 1] @ DRAM):
def matmul(K: size, C: size, W: size, R: size, w: R[K, C, R] @ DRAM, res: R[K, W] @ DRAM, y: R[C + 1, R + 1, W + 1] @ DRAM):
for k in seq(0, K):
for c in seq(0, C):
for r in seq(0, R):
for i in seq(0, W):
if 0 <= i - r:
res[k, i] += w[k, c, r] * y[c, r, i]
def im2col_conv(K: size, C: size, W: size, R: size, w: R[K, C, R] @ DRAM,
x: R[C, W] @ DRAM, res: R[K, W] @ DRAM):
def im2col_conv(K: size, C: size, W: size, R: size, w: R[K, C, R] @ DRAM, x: R[C, W] @ DRAM, res: R[K, W] @ DRAM):
for k_init in seq(0, K):
for i_init in seq(0, W):
res[k_init, i_init] = 0.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,4 @@ def simple_math_neon_sched(n: size, x: R[n] @ DRAM, y: R[n] @ DRAM):
neon_vst_4xf32(x[4 * io:4 + 4 * io], xVec[0:4])
if n % 4 > 0:
for ii in seq(0, n % 4):
x[ii + n / 4 *
4] = x[ii + n / 4 * 4] * y[ii + n / 4 * 4] * y[ii + n / 4 * 4]
x[ii + n / 4 * 4] = x[ii + n / 4 * 4] * y[ii + n / 4 * 4] * y[ii + n / 4 * 4]
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,4 @@ def foo(N: size, x: R[N, N] @ DRAM):
for j in seq(0, N):
for i in seq(0, N):
if 0 < i and i < N - 1 and (0 < j and j < N - 1):
x[i, j] += -1.0 / 4.0 * (x[i - 1, j] + x[i + 1, j] +
x[i, j - 1] + x[i, j + 1])
x[i, j] += -1.0 / 4.0 * (x[i - 1, j] + x[i + 1, j] + x[i, j - 1] + x[i, j + 1])
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,4 @@ def foo(N: size, x: R[N, N] @ DRAM):
for j in seq(0, N):
for i in seq(0, N):
if i > 0 and j > 0:
x[i, j] += -1.0 / 3.0 * (x[i - 1, j] + x[i - 1, j - 1] +
x[i, j - 1])
x[i, j] += -1.0 / 3.0 * (x[i - 1, j] + x[i - 1, j - 1] + x[i, j - 1])
3 changes: 1 addition & 2 deletions tests/golden/test_schedules/test_cut_loop_syrk.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
def SYRK(M: size, K: size, A: [f32][M, K] @ DRAM, A_t: [f32][M, K] @ DRAM,
C: [f32][M, M] @ DRAM):
def SYRK(M: size, K: size, A: [f32][M, K] @ DRAM, A_t: [f32][M, K] @ DRAM, C: [f32][M, M] @ DRAM):
assert M >= 1
assert K >= 1
assert stride(A, 1) == 1
Expand Down
6 changes: 2 additions & 4 deletions tests/golden/test_schedules/test_extract_subproc7.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
def gemv(m: size, n: size, alpha: R @ DRAM, beta: R @ DRAM,
A: [R][m, n] @ DRAM, x: [R][n] @ DRAM, y: [R][m] @ DRAM):
def gemv(m: size, n: size, alpha: R @ DRAM, beta: R @ DRAM, A: [R][m, n] @ DRAM, x: [R][n] @ DRAM, y: [R][m] @ DRAM):
assert stride(A, 1) == 1
for i in seq(0, m):
y[i] = y[i] * beta
for j in seq(0, n):
fooooo(m, n, alpha, A, x, y, j)
def fooooo(m: size, n: size, alpha: R @ DRAM, A: [R][m, n] @ DRAM,
x: [R][n] @ DRAM, y: [R][m] @ DRAM, j: index):
def fooooo(m: size, n: size, alpha: R @ DRAM, A: [R][m, n] @ DRAM, x: [R][n] @ DRAM, y: [R][m] @ DRAM, j: index):
assert stride(A, 1) == 1
assert 0 <= j
assert j < n
Expand Down
Loading
Loading