Skip to content

Commit ccff670

Browse files
Initial implementation of dynamic test generation
This commit contains the initial implementation of dynamic test generation for our graphblas-opt tests. We now have a framework for writing parameterized tests that generate *mlir files at test-running time.
1 parent 22d7846 commit ccff670

19 files changed

+1628
-88
lines changed

conftest.py

Lines changed: 330 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,338 @@
11
import os
22
import distutils.core
33
import subprocess
4+
import json
5+
import jinja2
6+
import itertools
7+
import random
8+
import re
9+
import pathlib
10+
import time
11+
import zlib
412

13+
from typing import Any, Dict, List, Tuple, Generator, Callable, Literal, Union, Optional
514

6-
def pytest_configure(config):
7-
distutils.core.run_setup(
8-
"./setup.py", script_args=["build_ext", "--inplace"], stop_after="run"
15+
##########
16+
# Config #
17+
##########
18+
19+
20+
def pytest_configure(config: "_pytest.config.Config"):
21+
if not hasattr(config, "workerinput"): # only run once on main process
22+
distutils.core.run_setup(
23+
"./setup.py", script_args=["build_ext", "--inplace"], stop_after="run"
24+
)
25+
# Ensure graphblas-opt is built
26+
subprocess.run(["python", os.path.join("mlir_graphblas", "src", "build.py")])
27+
28+
return
29+
30+
31+
def pytest_addoption(parser: "_pytest.config.argparsing.Parser"):
32+
parser.addoption(
33+
"--filecheck-sampling",
34+
action="store",
35+
default="default",
36+
help="Method to sample the space of the templatized FileCheck tests.",
37+
)
38+
return
39+
40+
41+
def pytest_generate_tests(metafunc: "_pytest.python.Metafunc"):
42+
if metafunc.function.__name__ == "test_filecheck_mlir":
43+
parameterize_templatized_filecheck_test(metafunc)
44+
return
45+
46+
47+
###########################################
48+
# Parameterization of test_filecheck_mlir #
49+
###########################################
50+
51+
NAMED_PARAMETER_VALUE_CHOICES = {
52+
"STANDARD_ELEMENT_TYPES": [
53+
"i1",
54+
"i4",
55+
"i8",
56+
"i16",
57+
"i32",
58+
"i64",
59+
"f16",
60+
"bf16",
61+
"f32",
62+
"f64",
63+
"f80",
64+
"f128",
65+
],
66+
"MATRIX_APPLY_OPERATORS": ["min"],
67+
"SPARSITY_TYPES": ["dense", "compressed", "singleton"],
68+
"MATRIX_MULTIPLY_SEMIRINGS": ["plus_times", "plus_pair", "plus_plus"],
69+
}
70+
71+
PREFILTER_FUNCTIONS = {None: lambda *args: True}
72+
73+
# Prefilter Functions
74+
75+
76+
def prefilter_func(
77+
func: Callable[[Dict[str, str]], bool]
78+
) -> Callable[[Dict[str, str]], bool]:
79+
if func.__name__ in PREFILTER_FUNCTIONS:
80+
raise Exception(
81+
f"{repr(func.__name__)} is already in use as prefilter function name."
82+
)
83+
PREFILTER_FUNCTIONS[func.__name__] = func
84+
return func
85+
86+
87+
@prefilter_func
88+
def different_thunk_and_element_type(parameter_dict: Dict[str, str]) -> bool:
89+
return parameter_dict["element_type"] != parameter_dict["thunk_type"]
90+
91+
92+
@prefilter_func
93+
def sparse_but_not_compressed_sparse(parameter_dict: Dict[str, str]) -> bool:
94+
sparsity0 = parameter_dict["sparsity0"]
95+
sparsity1 = parameter_dict["sparsity1"]
96+
return "compressed" in (sparsity0, sparsity1) and (
97+
(sparsity0, sparsity1) != ("dense", "compressed")
98+
)
99+
100+
101+
# Template Expansion
102+
103+
104+
def lazy_list_shuffler(input_list: list) -> Generator[Any, None, None]:
105+
"""
106+
This generator yields an unseen item at random from the input_list.
107+
Upon each call to the generator, it'll randomly select an index
108+
until an unseen index is found.
109+
110+
Running this generator to exhaustion is approximately O(n**2).
111+
This should only be used in cases where the generator is NOT expected
112+
to be exhausted and is expected to run a few times, in
113+
which case the expected running time is Theta(1).
114+
"""
115+
seen_indices = set()
116+
for _ in range(len(input_list)):
117+
while (index := random.randint(0, len(input_list) - 1)) in seen_indices:
118+
pass
119+
seen_indices.add(index)
120+
yield input_list[index]
121+
return
122+
123+
124+
def parameter_tuples_from_templates(
125+
sampling: Union[int, float, Literal["exhaustive", "bucket"]],
126+
seed: Optional[int] = None,
127+
) -> List[Tuple[str, Tuple[str, str, Dict[str, str]]]]:
128+
"""
129+
Returns a list of tuples of the form (
130+
"mlir_code",
131+
"test command jinja template",
132+
"/template/file/location",
133+
{"template_parameter_0": "template_parameter_value_0", ...}
134+
)
135+
136+
If sampling is a float (in the range [0, 1]), the number of parameter
137+
tuples returned will be (total_num_possible_cases * sampling) for each
138+
template.
139+
140+
If sampling is an int, the number of parameter tuples returned will
141+
be (sampling) for each template.
142+
143+
If sampling is "exhaustive", we will return all possible parameter
144+
tuples for each template.
145+
146+
If sampling is "bucket", we will yield parameter tuples as follows:
147+
For each template parameter name:
148+
Randomly sample one template parameter value for all other
149+
template parameter names
150+
Yield the pytest param for the set of template parameter
151+
"""
152+
if seed is not None:
153+
random.seed(seed)
154+
parameter_tuples = []
155+
current_module_dir = pathlib.Path(__file__).parent.absolute()
156+
template_files = (
157+
os.path.join(root, f)
158+
for root, _, files in os.walk(current_module_dir, followlinks=True)
159+
for f in files
160+
if f.endswith(".template.mlir")
161+
)
162+
for template_file in template_files:
163+
# Parse the template file
164+
with open(template_file, "r") as f:
165+
json_sting, mlir_template = f.read().split("### START TEST ###")
166+
test_spec: dict = json.loads(json_sting)
167+
mlir_template = jinja2.Template(mlir_template, undefined=jinja2.StrictUndefined)
168+
if "parameters" not in test_spec:
169+
raise ValueError(
170+
f"{template_file} does not contain a valid test specification as "
171+
'it does not specify a value for the key "parameters".'
172+
)
173+
elif "run" not in test_spec or not isinstance(test_spec["run"], str):
174+
raise ValueError(
175+
f"{template_file} does not contain a valid test specification as "
176+
'it does not specify a valid value for the key "run".'
177+
)
178+
prefilter_name = test_spec.get("prefilter")
179+
parameter_dict_filter = PREFILTER_FUNCTIONS.get(prefilter_name)
180+
if parameter_dict_filter is None:
181+
raise NameError(f"Unknown prefilter function named {repr(prefilter_name)}.")
182+
183+
# Grab test running command
184+
test_execution_command_template = test_spec["run"]
185+
186+
# Grab parameter choices
187+
parameter_choices: Dict[str, List[str]] = dict()
188+
for parameter_name, parameter_value_choices in test_spec["parameters"].items():
189+
if (
190+
isinstance(parameter_value_choices, str)
191+
and parameter_value_choices in NAMED_PARAMETER_VALUE_CHOICES
192+
):
193+
parameter_choices[parameter_name] = NAMED_PARAMETER_VALUE_CHOICES[
194+
parameter_value_choices
195+
]
196+
elif isinstance(parameter_value_choices, list):
197+
parameter_choices[parameter_name] = parameter_value_choices
198+
else:
199+
raise ValueError(
200+
f"{repr(parameter_value_choices)} does not specify a valid set of parameter values."
201+
)
202+
203+
# Handle each sampling case separately
204+
if sampling == "bucket":
205+
parameter_dicts = []
206+
for parameter_name, parameter_possible_values in parameter_choices.items():
207+
for parameter_value in parameter_possible_values:
208+
# Set up lazy iterators to randomly grab the values of all the other parameters
209+
other_parameter_names = [
210+
name
211+
for name in parameter_choices.keys()
212+
if name != parameter_name
213+
]
214+
other_parameter_choices_values = (
215+
parameter_choices[name] for name in other_parameter_names
216+
)
217+
other_parameter_choices_values = map(
218+
lazy_list_shuffler, other_parameter_choices_values
219+
)
220+
other_parameter_value_tuples = itertools.product(
221+
*other_parameter_choices_values
222+
)
223+
other_parameter_dicts = (
224+
dict(zip(other_parameter_names, other_parameter_values))
225+
for other_parameter_values in other_parameter_value_tuples
226+
)
227+
# Go through possible parameter dicts until we find a valid one
228+
for parameter_dict in other_parameter_dicts:
229+
parameter_dict[parameter_name] = parameter_value
230+
if parameter_dict_filter(parameter_dict):
231+
parameter_tuples.append(
232+
(
233+
generate_test_id_string(
234+
template_file, parameter_dict
235+
),
236+
(
237+
mlir_template.render(**parameter_dict),
238+
test_execution_command_template,
239+
template_file,
240+
parameter_dict,
241+
),
242+
)
243+
)
244+
break
245+
else:
246+
if isinstance(sampling, int):
247+
248+
def sampling_method(parameter_dicts):
249+
parameter_dicts = list(parameter_dicts)
250+
num_samples = min(sampling, len(parameter_dicts))
251+
return random.sample(parameter_dicts, num_samples)
252+
253+
elif isinstance(sampling, float):
254+
if sampling < 0 or sampling > 1:
255+
raise ValueError(
256+
f"Portion of parameter dicts to sample must be "
257+
f"in the range [0, 1], got {sampling}."
258+
)
259+
260+
def sampling_method(parameter_dicts):
261+
parameter_dicts = list(parameter_dicts)
262+
num_samples = int(len(parameter_dicts) * sampling)
263+
return random.sample(parameter_dicts, num_samples)
264+
265+
elif sampling == "default":
266+
267+
def sampling_method(parameter_dicts):
268+
for parameter_dict in parameter_dicts:
269+
return [parameter_dict]
270+
return []
271+
272+
elif sampling == "exhaustive":
273+
274+
def sampling_method(parameter_dicts):
275+
return parameter_dicts
276+
277+
else:
278+
raise ValueError(
279+
f"{repr(sampling)} is not a supported sampling method."
280+
)
281+
282+
# Grab all possible parameter dicts
283+
parameter_names = parameter_choices.keys()
284+
parameter_value_tuples = itertools.product(*parameter_choices.values())
285+
all_parameter_dicts = (
286+
dict(zip(parameter_names, parameter_values))
287+
for parameter_values in parameter_value_tuples
288+
)
289+
290+
# Find the requested parameter dicts
291+
parameter_dicts = filter(parameter_dict_filter, all_parameter_dicts)
292+
parameter_dicts = sampling_method(parameter_dicts)
293+
294+
# Append one parameter dict for each test case
295+
for parameter_dict in parameter_dicts:
296+
parameter_tuples.append(
297+
(
298+
generate_test_id_string(template_file, parameter_dict),
299+
(
300+
mlir_template.render(**parameter_dict),
301+
test_execution_command_template,
302+
template_file,
303+
parameter_dict,
304+
),
305+
)
306+
)
307+
308+
return parameter_tuples
309+
310+
311+
def generate_test_id_string(template_file: str, parameter_dict: Dict[str, str]) -> str:
312+
return "".join(c for c in template_file if c.isalnum()) + "".join(
313+
f"({re.escape(key)}:{re.escape(parameter_dict[key])})"
314+
for key in sorted(parameter_dict.keys())
315+
)
316+
317+
318+
def parameterize_templatized_filecheck_test(metafunc: "_pytest.python.Metafunc"):
319+
sampling_method_string = metafunc.config.getoption("--filecheck-sampling")
320+
if sampling_method_string.isdigit():
321+
sampling = int(sampling_method_string)
322+
elif re.match("^\d+(\.\d+)?$", sampling_method_string):
323+
sampling = float(sampling_method_string)
324+
else:
325+
sampling = sampling_method_string
326+
seed = (
327+
zlib.adler32(metafunc.config.workerinput["testrunuid"].encode())
328+
if hasattr(metafunc.config, "workerinput")
329+
else time.time()
330+
)
331+
ids, parameter_values = zip(*parameter_tuples_from_templates(sampling, seed))
332+
metafunc.parametrize(
333+
["mlir_code", "test_command_template", "template_file", "parameter_dict"],
334+
parameter_values,
335+
ids=ids,
9336
)
10337

11338
# Ensure graphblas-opt is built

continuous_integration/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ dependencies:
1010
- coverage
1111
- pytest
1212
- pytest-cov
13+
- pytest-xdist
1314
- black
1415

1516

dev-environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ dependencies:
1010
- coverage
1111
- pytest
1212
- pytest-cov
13+
- pytest-xdist
1314
- black
1415

1516
# documentation

mlir_graphblas/mlir_builder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,8 @@ class MLIRFunctionBuilder(BaseFunction):
135135
+ "{{statements}}"
136136
+ "\n"
137137
+ (" " * default_indentation_size)
138-
+ "}"
138+
+ "}",
139+
undefined=jinja2.StrictUndefined,
139140
)
140141

141142
def __init__(

mlir_graphblas/src/test/GraphBLAS/graphblas-opt.mlir

Lines changed: 0 additions & 3 deletions
This file was deleted.

0 commit comments

Comments
 (0)