Skip to content

Commit f75b1ba

Browse files
New Feature: Kerngraph support for Relin (#109)
Refactored kerngraph to support grouping of p-isa lists to enable flexible reordering. Added regression testing for Relin Kernel. Co-authored-by: Jose Rojas Chaves <jose.rojas.chaves@intel.com>
1 parent 06a5a25 commit f75b1ba

File tree

5 files changed

+830
-50
lines changed

5 files changed

+830
-50
lines changed

p-isa_tools/kerngen/kernel_optimization/loops.py

Lines changed: 58 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,36 @@
1010
from high_parser.pisa_operations import PIsaOp, Comment
1111

1212

13+
class PIsaOpGroup:
14+
"""A group of PIsaOp instructions with reorderable flag.
15+
16+
Attributes:
17+
pisa_list: List of PIsaOp instructions
18+
is_reorderable: Boolean indicating if the instructions can be reordered
19+
"""
20+
21+
pisa_list: list[PIsaOp]
22+
is_reorderable: bool
23+
24+
def __init__(self, pisa_list: list[PIsaOp], is_reorderable: bool = False):
25+
"""Initialize PIsaOpGroup.
26+
27+
Args:
28+
pisa_list: List of PIsaOp instructions
29+
is_reorderable: Boolean indicating if instructions can be reordered
30+
"""
31+
self.pisa_list = pisa_list
32+
self.is_reorderable = is_reorderable
33+
34+
def __len__(self) -> int:
35+
"""Return the number of instructions in the group."""
36+
return len(self.pisa_list)
37+
38+
def __iter__(self):
39+
"""Allow iteration over the PIsaOp instructions."""
40+
return iter(self.pisa_list)
41+
42+
1343
def remove_comments(pisa_list: list[PIsaOp]) -> list[PIsaOp]:
1444
"""Remove comments from a list of PIsaOp instructions.
1545
@@ -22,40 +52,49 @@ def remove_comments(pisa_list: list[PIsaOp]) -> list[PIsaOp]:
2252
return [pisa for pisa in pisa_list if not isinstance(pisa, Comment)]
2353

2454

25-
def split_by_reorderable(pisa_list: list[PIsaOp]) -> tuple[list[PIsaOp], list[PIsaOp]]:
55+
def split_by_reorderable(pisa_list: list[PIsaOp]) -> list[PIsaOpGroup]:
2656
"""Split a list of PIsaOp instructions into reorderable and non-reorderable groups.
2757
2858
Args:
2959
pisa_list: List of PIsaOp instructions
3060
3161
Returns:
32-
Tuple containing two lists:
33-
- reorderable: Instructions that can be reordered
34-
- non_reorderable: Instructions that cannot be reordered
62+
List of PIsaOpGroup objects containing grouped instructions with their reorderable status
3563
"""
36-
37-
reorderable = []
38-
non_reorderable = []
39-
is_reorderable = False
64+
groups = []
65+
current_group = PIsaOpGroup([], is_reorderable=False)
66+
no_reoderable_group = True
4067

4168
for pisa in pisa_list:
4269
# if the pisa is a comment and it contains <reorderable> tag, treat the following pisa as reorderable until a </reorderable> tag is found.
4370
if isinstance(pisa, Comment):
4471
if "<reorderable>" in pisa.line:
45-
is_reorderable = True
72+
# If current group has instructions, append it to groups first
73+
if current_group.pisa_list:
74+
groups.append(current_group)
75+
# Create a new reorderable group
76+
current_group = PIsaOpGroup([], is_reorderable=True)
77+
no_reoderable_group = False
4678
elif "</reorderable>" in pisa.line:
47-
is_reorderable = False
48-
49-
if is_reorderable:
50-
reorderable.append(pisa)
79+
# End reorderable section, append current group to groups
80+
if current_group.pisa_list:
81+
groups.append(current_group)
82+
# Create a new non-reorderable group
83+
current_group = PIsaOpGroup([], is_reorderable=False)
5184
else:
52-
non_reorderable.append(pisa)
85+
# Add non-comment instruction to current group
86+
current_group.pisa_list.append(pisa)
87+
88+
# Add any remaining instructions in current_group
89+
if current_group.pisa_list:
90+
groups.append(current_group)
91+
92+
# If there are no reorderable groups, set reorderable to True for the entire groups
93+
if no_reoderable_group:
94+
for group in groups:
95+
group.is_reorderable = True
5396

54-
# if reoderable is empty, return non_reorderable as reorderable
55-
if not reorderable:
56-
reorderable = non_reorderable
57-
non_reorderable = []
58-
return remove_comments(reorderable), remove_comments(non_reorderable)
97+
return groups
5998

6099

61100
def loop_interchange(

p-isa_tools/kerngen/kerngraph.py

Lines changed: 51 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def parse_args():
5252
nargs="*",
5353
default=[],
5454
# Composition high ops such are ntt, mod, and relin are not currently supported
55-
choices=["add", "sub", "mul", "muli", "copy", "ntt", "intt", "mod"],
55+
choices=["add", "sub", "mul", "muli", "copy", "ntt", "intt", "mod", "relin"],
5656
help="List of high_op names",
5757
)
5858
parser.add_argument(
@@ -78,47 +78,67 @@ def parse_args():
7878
return parser.parse_args()
7979

8080

81-
def main(args):
82-
"""Main function to read input and parse each line with KernelParser."""
83-
input_lines = sys.stdin.read().strip().splitlines()
81+
def parse_kernels(input_lines, debug=False):
82+
"""Parse kernel strings from input lines."""
8483
valid_kernels = []
85-
Config.legacy_mode = args.legacy
86-
8784
for line in input_lines:
8885
try:
8986
kernel = KernelParser.parse_kernel(line)
9087
valid_kernels.append(kernel)
9188
except ValueError as e:
92-
if args.debug:
89+
if debug:
9390
print(f"Error parsing line: {line}\nReason: {e}")
9491
continue # Skip invalid lines
92+
return valid_kernels
93+
94+
95+
def process_kernel_with_reordering(kernel, args):
96+
"""Process a kernel with reordering optimization."""
97+
groups = split_by_reorderable(kernel.to_pisa())
98+
processed_kernel = []
99+
for group in groups:
100+
if group.is_reorderable:
101+
processed_kernel.append(
102+
loop_interchange(
103+
group.pisa_list,
104+
primary_key=args.primary,
105+
secondary_key=args.secondary,
106+
)
107+
)
108+
else:
109+
processed_kernel.append(group.pisa_list)
110+
111+
for pisa in mixed_to_pisa_ops(processed_kernel):
112+
print(pisa)
113+
114+
115+
def should_apply_reordering(kernel, targets):
116+
"""Check if reordering should be applied to this kernel."""
117+
return targets and any(target.lower() in str(kernel).lower() for target in targets)
118+
119+
120+
def main(args):
121+
"""Main function to read input and parse each line with KernelParser."""
122+
input_lines = sys.stdin.read().strip().splitlines()
123+
Config.legacy_mode = args.legacy
124+
125+
valid_kernels = parse_kernels(input_lines, args.debug)
95126

96127
if not valid_kernels:
97128
print("No valid kernel strings were parsed.")
98-
else:
99-
if args.debug:
100-
print(
101-
f"# Reordered targets {args.target} with primary key {args.primary} and secondary key {args.secondary}"
102-
)
103-
for kernel in valid_kernels:
104-
if args.target and any(
105-
target.lower() in str(kernel).lower() for target in args.target
106-
):
107-
reorderable, non_reorderable = split_by_reorderable(kernel.to_pisa())
108-
kernel = non_reorderable
109-
kernel.append(
110-
loop_interchange(
111-
reorderable,
112-
primary_key=args.primary,
113-
secondary_key=args.secondary,
114-
)
115-
)
116-
117-
for pisa in mixed_to_pisa_ops(kernel):
118-
print(pisa)
119-
else:
120-
for pisa in kernel.to_pisa():
121-
print(pisa)
129+
return
130+
131+
if args.debug:
132+
print(
133+
f"# Reordered targets {args.target} with primary key {args.primary} and secondary key {args.secondary}"
134+
)
135+
136+
for kernel in valid_kernels:
137+
if should_apply_reordering(kernel, args.target):
138+
process_kernel_with_reordering(kernel, args)
139+
else:
140+
for pisa in kernel.to_pisa():
141+
print(pisa)
122142

123143

124144
if __name__ == "__main__":

0 commit comments

Comments
 (0)