Skip to content

Commit 85caf92

Browse files
Move dynamic FileCheck test generation to pytest
1 parent 52be406 commit 85caf92

21 files changed

+1437
-1440
lines changed

conftest.py

Lines changed: 323 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,332 @@
11
import os
22
import distutils.core
33
import subprocess
4+
import json
5+
import jinja2
6+
import itertools
7+
import random
8+
import re
9+
import pathlib
10+
import time
11+
import zlib
412

13+
from typing import Any, Dict, List, Tuple, Generator, Callable, Literal, Union, Optional
514

6-
def pytest_configure(config):
7-
distutils.core.run_setup(
8-
"./setup.py", script_args=["build_ext", "--inplace"], stop_after="run"
15+
##########
16+
# Config #
17+
##########
18+
19+
20+
def pytest_configure(config: "_pytest.config.Config"):
21+
if not hasattr(config, "workerinput"): # only run once on main process
22+
distutils.core.run_setup(
23+
"./setup.py", script_args=["build_ext", "--inplace"], stop_after="run"
24+
)
25+
# Ensure graphblas-opt is built
26+
subprocess.run(["python", os.path.join("mlir_graphblas", "src", "build.py")])
27+
28+
return
29+
30+
31+
def pytest_addoption(parser: "_pytest.config.argparsing.Parser"):
32+
parser.addoption(
33+
"--filecheck-sampling",
34+
action="store",
35+
default="bucket",
36+
help="Method to sample the space of the templatized FileCheck tests.",
37+
)
38+
return
39+
40+
41+
def pytest_generate_tests(metafunc: "_pytest.python.Metafunc"):
42+
if metafunc.function.__name__ == "test_filecheck_mlir":
43+
parameterize_templatized_filecheck_test(metafunc)
44+
return
45+
46+
47+
###########################################
48+
# Parameterization of test_filecheck_mlir #
49+
###########################################
50+
51+
NAMED_PARAMETER_VALUE_CHOICES = {
52+
"STANDARD_ELEMENT_TYPES": [
53+
"i1",
54+
"i4",
55+
"i8",
56+
"i16",
57+
"i32",
58+
"i64",
59+
"f16",
60+
"bf16",
61+
"f32",
62+
"f64",
63+
"f80",
64+
"f128",
65+
],
66+
"MATRIX_APPLY_OPERATORS": ["min"],
67+
"SPARSITY_TYPES": ["dense", "compressed", "singleton"],
68+
"MATRIX_MULTIPLY_SEMIRINGS": ["plus_times", "plus_pair", "plus_plus"],
69+
}
70+
71+
PREFILTER_FUNCTIONS = {None: lambda *args: True}
72+
73+
# Prefilter Functions
74+
75+
76+
def prefilter_func(
77+
func: Callable[[Dict[str, str]], bool]
78+
) -> Callable[[Dict[str, str]], bool]:
79+
if func.__name__ in PREFILTER_FUNCTIONS:
80+
raise Exception(
81+
f"{repr(func.__name__)} is already in use as prefilter function name."
82+
)
83+
PREFILTER_FUNCTIONS[func.__name__] = func
84+
return func
85+
86+
87+
@prefilter_func
88+
def different_thunk_and_element_type(parameter_dict: Dict[str, str]) -> bool:
89+
return parameter_dict["element_type"] != parameter_dict["thunk_type"]
90+
91+
92+
@prefilter_func
93+
def sparse_but_not_compressed_sparse(parameter_dict: Dict[str, str]) -> bool:
94+
sparsity0 = parameter_dict["sparsity0"]
95+
sparsity1 = parameter_dict["sparsity1"]
96+
return "compressed" in (sparsity0, sparsity1) and (
97+
(sparsity0, sparsity1) != ("dense", "compressed")
998
)
1099

11-
# Ensure graphblas-opt is built
12-
subprocess.run(["python", os.path.join("mlir_graphblas", "src", "build.py")])
13100

101+
# Template Expansion
102+
103+
104+
def lazy_list_shuffler(input_list: list) -> Generator[Any, None, None]:
105+
"""
106+
This generator yields an unseen item at random from the input_list.
107+
Upon each call to the generator, it'll randomly select an index
108+
until an unseen index is found.
109+
110+
Running this generator to exhaustion is approximately O(n**2).
111+
This should only be used in cases where the generator is NOT expected
112+
to be exhausted and is expected to run a few times, in
113+
which case the expected running time is Theta(1).
114+
"""
115+
seen_indices = set()
116+
for _ in range(len(input_list)):
117+
while (index := random.randint(0, len(input_list) - 1)) in seen_indices:
118+
pass
119+
seen_indices.add(index)
120+
yield input_list[index]
121+
return
122+
123+
124+
def parameter_tuples_from_templates(
125+
sampling: Union[int, float, Literal["exhaustive", "bucket"]],
126+
seed: Optional[int] = None,
127+
) -> List[Tuple[str, Tuple[str, str, Dict[str, str]]]]:
128+
"""
129+
Returns a list of tuples of the form (
130+
"mlir_code",
131+
"test command jinja template",
132+
"/template/file/location",
133+
{"template_parameter_0": "template_parameter_value_0", ...}
134+
)
135+
136+
If sampling is a float (in the range [0, 1]), the number of parameter
137+
tuples returned will be (total_num_possible_cases * sampling) for each
138+
template.
139+
140+
If sampling is an int, the number of parameter tuples returned will
141+
be (sampling) for each template.
142+
143+
If sampling is "exhaustive", we will return all possible parameter
144+
tuples for each template.
145+
146+
If sampling is "bucket", we will yield parameter tuples as follows:
147+
For each template parameter name:
148+
Randomly sample one template parameter value for all other
149+
template parameter names
150+
Yield the pytest param for the set of template parameter
151+
"""
152+
if seed is not None:
153+
random.seed(seed)
154+
parameter_tuples = []
155+
current_module_dir = pathlib.Path(__file__).parent.absolute()
156+
template_files = (
157+
os.path.join(root, f)
158+
for root, _, files in os.walk(current_module_dir, followlinks=True)
159+
for f in files
160+
if f.endswith(".template.mlir")
161+
)
162+
for template_file in template_files:
163+
# Parse the template file
164+
with open(template_file, "r") as f:
165+
json_sting, mlir_template = f.read().split("### START TEST ###")
166+
test_spec: dict = json.loads(json_sting)
167+
mlir_template = jinja2.Template(mlir_template, undefined=jinja2.StrictUndefined)
168+
if "parameters" not in test_spec:
169+
raise ValueError(
170+
f"{template_file} does not contain a valid test specification as "
171+
'it does not specify a value for the key "parameters".'
172+
)
173+
elif "run" not in test_spec or not isinstance(test_spec["run"], str):
174+
raise ValueError(
175+
f"{template_file} does not contain a valid test specification as "
176+
'it does not specify a valid value for the key "run".'
177+
)
178+
prefilter_name = test_spec.get("prefilter")
179+
parameter_dict_filter = PREFILTER_FUNCTIONS.get(prefilter_name)
180+
if parameter_dict_filter is None:
181+
raise NameError(f"Unknown prefilter function named {repr(prefilter_name)}.")
182+
183+
# Grab test running command
184+
test_execution_command_template = test_spec["run"]
185+
186+
# Grab parameter choices
187+
parameter_choices: Dict[str, List[str]] = dict()
188+
for parameter_name, parameter_value_choices in test_spec["parameters"].items():
189+
if (
190+
isinstance(parameter_value_choices, str)
191+
and parameter_value_choices in NAMED_PARAMETER_VALUE_CHOICES
192+
):
193+
parameter_choices[parameter_name] = NAMED_PARAMETER_VALUE_CHOICES[
194+
parameter_value_choices
195+
]
196+
elif isinstance(parameter_value_choices, list):
197+
parameter_choices[parameter_name] = parameter_value_choices
198+
else:
199+
raise ValueError(
200+
f"{repr(parameter_value_choices)} does not specify a valid set of parameter values."
201+
)
202+
203+
# Handle each sampling case separately
204+
if sampling == "bucket":
205+
parameter_dicts = []
206+
for parameter_name, parameter_possible_values in parameter_choices.items():
207+
for parameter_value in parameter_possible_values:
208+
# Set up lazy iterators to randomly grab the values of all the other parameters
209+
other_parameter_names = [
210+
name
211+
for name in parameter_choices.keys()
212+
if name != parameter_name
213+
]
214+
other_parameter_choices_values = (
215+
parameter_choices[name] for name in other_parameter_names
216+
)
217+
other_parameter_choices_values = map(
218+
lazy_list_shuffler, other_parameter_choices_values
219+
)
220+
other_parameter_value_tuples = itertools.product(
221+
*other_parameter_choices_values
222+
)
223+
other_parameter_dicts = (
224+
dict(zip(other_parameter_names, other_parameter_values))
225+
for other_parameter_values in other_parameter_value_tuples
226+
)
227+
# Go through possible parameter dicts until we find a valid one
228+
for parameter_dict in other_parameter_dicts:
229+
parameter_dict[parameter_name] = parameter_value
230+
if parameter_dict_filter(parameter_dict):
231+
parameter_tuples.append(
232+
(
233+
generate_test_id_string(
234+
template_file, parameter_dict
235+
),
236+
(
237+
mlir_template.render(**parameter_dict),
238+
test_execution_command_template,
239+
template_file,
240+
parameter_dict,
241+
),
242+
)
243+
)
244+
break
245+
else:
246+
if isinstance(sampling, int):
247+
248+
def sampling_method(parameter_dicts):
249+
parameter_dicts = list(parameter_dicts)
250+
num_samples = min(sampling, len(parameter_dicts))
251+
return random.sample(parameter_dicts, num_samples)
252+
253+
elif isinstance(sampling, float):
254+
if sampling < 0 or sampling > 1:
255+
raise ValueError(
256+
f"Portion of parameter dicts to sample must be "
257+
f"in the range [0, 1], got {sampling}."
258+
)
259+
260+
def sampling_method(parameter_dicts):
261+
parameter_dicts = list(parameter_dicts)
262+
num_samples = int(len(parameter_dicts) * sampling)
263+
return random.sample(parameter_dicts, num_samples)
264+
265+
elif sampling == "exhaustive":
266+
267+
def sampling_method(parameter_dicts):
268+
return parameter_dicts
269+
270+
else:
271+
raise ValueError(
272+
f"{repr(sampling)} is not a supported sampling method."
273+
)
274+
275+
# Grab all possible parameter dicts
276+
parameter_names = parameter_choices.keys()
277+
parameter_value_tuples = itertools.product(*parameter_choices.values())
278+
all_parameter_dicts = (
279+
dict(zip(parameter_names, parameter_values))
280+
for parameter_values in parameter_value_tuples
281+
)
282+
283+
# Find the requested parameter dicts
284+
parameter_dicts = filter(parameter_dict_filter, all_parameter_dicts)
285+
parameter_dicts = sampling_method(parameter_dicts)
286+
287+
# Append one parameter dict for each test case
288+
for parameter_dict in parameter_dicts:
289+
parameter_tuples.append(
290+
(
291+
generate_test_id_string(template_file, parameter_dict),
292+
(
293+
mlir_template.render(**parameter_dict),
294+
test_execution_command_template,
295+
template_file,
296+
parameter_dict,
297+
),
298+
)
299+
)
300+
301+
return parameter_tuples
302+
303+
304+
def generate_test_id_string(template_file: str, parameter_dict: Dict[str, str]) -> str:
305+
return "".join(c for c in template_file if c.isalnum()) + "".join(
306+
f"({re.escape(key)}:{re.escape(parameter_dict[key])})"
307+
for key in sorted(parameter_dict.keys())
308+
)
309+
310+
311+
def parameterize_templatized_filecheck_test(metafunc: "_pytest.python.Metafunc"):
312+
sampling_method_string = metafunc.config.getoption("--filecheck-sampling")
313+
if sampling_method_string.isdigit():
314+
sampling = int(sampling_method_string)
315+
elif re.match("^\d+(\.\d+)?$", sampling_method_string):
316+
sampling = float(sampling_method_string)
317+
else:
318+
sampling = sampling_method_string
319+
seed = (
320+
zlib.adler32(metafunc.config.workerinput["testrunuid"].encode())
321+
if hasattr(metafunc.config, "workerinput")
322+
else time.time()
323+
)
324+
ids, parameter_values = zip(
325+
*parameter_tuples_from_templates(sampling, seed)
326+
)
327+
metafunc.parametrize(
328+
["mlir_code", "test_command_template", "template_file", "parameter_dict"],
329+
parameter_values,
330+
ids=ids,
331+
)
14332
return

dev-environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ dependencies:
1010
- coverage
1111
- pytest
1212
- pytest-cov
13+
- pytest-xdist
1314
- black
1415

1516
# documentation

0 commit comments

Comments
 (0)