Skip to content

Commit 7ff5cc9

Browse files
Move dynamic FileCheck test generation to pytest
1 parent 52be406 commit 7ff5cc9

21 files changed

+1520
-1292
lines changed

conftest.py

Lines changed: 317 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,326 @@
11
import os
22
import distutils.core
33
import subprocess
4+
import json
5+
import jinja2
6+
import itertools
7+
import random
8+
import re
9+
import pathlib
10+
import time
11+
import zlib
412

13+
from typing import Any, Dict, List, Tuple, Generator, Callable, Literal, Union, Optional
514

6-
def pytest_configure(config):
7-
distutils.core.run_setup(
8-
"./setup.py", script_args=["build_ext", "--inplace"], stop_after="run"
15+
##########
16+
# Config #
17+
##########
18+
19+
20+
def pytest_configure(config: "_pytest.config.Config"):
    """
    Build the compiled extensions and graphblas-opt before the tests run.

    The build is skipped whenever ``config`` has a ``workerinput``
    attribute, so that it only happens once on the main process.

    Raises subprocess.CalledProcessError if the graphblas-opt build
    script exits with a non-zero status.
    """
    import sys

    if not hasattr(config, "workerinput"):  # only run once on main process
        distutils.core.run_setup(
            "./setup.py", script_args=["build_ext", "--inplace"], stop_after="run"
        )
        # Ensure graphblas-opt is built.
        # sys.executable is used so the build script runs under the same
        # interpreter as pytest (a bare "python" may be missing from PATH or
        # point at a different environment); check=True makes a failed build
        # stop the test session instead of being silently ignored.
        subprocess.run(
            [sys.executable, os.path.join("mlir_graphblas", "src", "build.py")],
            check=True,
        )

    return
29+
30+
31+
def pytest_addoption(parser: "_pytest.config.argparsing.Parser"):
    """
    Register the ``--filecheck-sampling`` command line option.

    The option stores its value as given and defaults to "bucket".
    """
    option_settings = dict(
        action="store",
        default="bucket",
        help="Method to sample the space of the templatized FileCheck tests.",
    )
    parser.addoption("--filecheck-sampling", **option_settings)
    return
39+
40+
41+
def pytest_generate_tests(metafunc: "_pytest.python.Metafunc"):
    """
    Parametrize the test function named "test_filecheck_mlir"; every other
    test function is left untouched.
    """
    test_function_name = metafunc.function.__name__
    if test_function_name != "test_filecheck_mlir":
        return
    parameterize_templatized_filecheck_test(metafunc)
    return
45+
46+
47+
###########################################
48+
# Parameterization of test_filecheck_mlir #
49+
###########################################
50+
51+
# Named sets of parameter values that a template's test specification can
# refer to by name instead of spelling out the list of values.
NAMED_PARAMETER_VALUE_CHOICES = dict(
    STANDARD_ELEMENT_TYPES=[
        # integer types
        "i1", "i4", "i8", "i16", "i32", "i64",
        # floating point types
        "f16", "bf16", "f32", "f64", "f80", "f128",
    ],
    MATRIX_APPLY_OPERATORS=["min"],
    SPARSITY_TYPES=["dense", "compressed", "singleton"],
    MATRIX_MULTIPLY_SEMIRINGS=["plus_times", "plus_pair", "plus_plus"],
)

# Registry of prefilter functions keyed by name. The None key is the
# prefilter used when a test specification names no prefilter; it accepts
# every parameter dict.
PREFILTER_FUNCTIONS = {None: lambda *_ignored: True}
72+
73+
# Prefilter Functions
74+
75+
def prefilter_func(func: Callable[[Dict[str, str]], bool]) -> Callable[[Dict[str, str]], bool]:
    """
    Decorator that registers func in PREFILTER_FUNCTIONS under its own
    __name__ and hands func back unchanged.

    Raises an Exception if that name has already been registered.
    """
    name = func.__name__
    if name not in PREFILTER_FUNCTIONS:
        PREFILTER_FUNCTIONS[name] = func
        return func
    raise Exception(f"{repr(name)} is already in use as prefilter function name.")
80+
81+
@prefilter_func
def different_thunk_and_element_type(parameter_dict: Dict[str, str]) -> bool:
    """
    Accept only parameter dicts whose "element_type" and "thunk_type"
    values differ.
    """
    element_type = parameter_dict["element_type"]
    thunk_type = parameter_dict["thunk_type"]
    return not (element_type == thunk_type)
84+
85+
@prefilter_func
def sparse_but_not_compressed_sparse(parameter_dict: Dict[str, str]) -> bool:
    """
    Accept parameter dicts where at least one of "sparsity0" / "sparsity1"
    is "compressed", except for the exact pair ("dense", "compressed").
    """
    sparsity_pair = (parameter_dict["sparsity0"], parameter_dict["sparsity1"])
    if "compressed" not in sparsity_pair:
        return False
    return sparsity_pair != ("dense", "compressed")
90+
91+
# Template Expansion
92+
93+
def lazy_list_shuffler(input_list: list) -> Generator[Any, None, None]:
    """
    Yield the items of input_list one at a time in random order, without
    shuffling the list up front.

    Every step draws random indices until it hits one that has not been
    handed out yet, then yields the item at that index. The retries become
    more frequent as more indices have been used, so draining the generator
    completely costs noticeably more than a plain shuffle. It is meant for
    callers that only pull a few items, where each pull is cheap.
    """
    used_indices = set()
    steps_left = len(input_list)
    while steps_left > 0:
        steps_left -= 1
        # Rejection sampling: redraw until the index is new.
        while True:
            candidate = random.randint(0, len(input_list) - 1)
            if candidate not in used_indices:
                break
        used_indices.add(candidate)
        yield input_list[candidate]
    return
111+
112+
113+
def _build_sampling_method(
    sampling: Union[int, float, str],
) -> Optional[Callable[[Any], Any]]:
    # Turn a sampling specification into a function that picks parameter dicts
    # out of an iterable of candidates; "bucket" has no such function (None).
    if sampling == "bucket":
        return None

    if isinstance(sampling, int):

        def sampling_method(parameter_dicts):
            parameter_dicts = list(parameter_dicts)
            num_samples = min(sampling, len(parameter_dicts))
            return random.sample(parameter_dicts, num_samples)

    elif isinstance(sampling, float):
        if sampling < 0 or sampling > 1:
            raise ValueError(
                f"Portion of parameter dicts to sample must be "
                f"in the range [0, 1], got {sampling}."
            )

        def sampling_method(parameter_dicts):
            parameter_dicts = list(parameter_dicts)
            num_samples = int(len(parameter_dicts) * sampling)
            return random.sample(parameter_dicts, num_samples)

    elif sampling == "exhaustive":

        def sampling_method(parameter_dicts):
            return parameter_dicts

    else:
        raise ValueError(f"{repr(sampling)} is not a supported sampling method.")

    return sampling_method


def _make_parameter_tuple(
    template_file: str,
    mlir_template: Any,
    test_execution_command_template: str,
    parameter_dict: Dict[str, str],
) -> Tuple[str, Tuple[str, str, str, Dict[str, str]]]:
    # Bundle one test case as (test id, (rendered MLIR, run command template,
    # template file path, parameter dict)).
    return (
        generate_test_id_string(template_file, parameter_dict),
        (
            mlir_template.render(**parameter_dict),
            test_execution_command_template,
            template_file,
            parameter_dict,
        ),
    )


def parameter_tuples_from_templates(
    sampling: Union[int, float, Literal["exhaustive", "bucket"]],
    seed: Optional[Union[int, float]] = None,
) -> List[Tuple[str, Tuple[str, str, str, Dict[str, str]]]]:
    """
    Expand every "*.template.mlir" file found under the directory of this
    file (searched recursively, following symlinks) into test cases.

    Returns a list of tuples of the form (
        "test id string",
        (
            "mlir_code",
            "test command jinja template",
            "/template/file/location",
            {"template_parameter_0": "template_parameter_value_0", ...},
        ),
    )

    A template file consists of a JSON test specification, the marker line
    "### START TEST ###" and the jinja template of the MLIR code. The
    specification must have the keys "parameters" (parameter name mapped to
    either a list of values or the name of an entry of
    NAMED_PARAMETER_VALUE_CHOICES) and "run" (a string); it may name a
    registered prefilter function under the key "prefilter".

    If sampling is a float (in the range [0, 1]), the number of parameter
    tuples returned will be (total_num_possible_cases * sampling) for each
    template, counting only cases accepted by the prefilter.

    If sampling is an int, the number of parameter tuples returned will
    be at most (sampling) for each template.

    If sampling is "exhaustive", we will return all possible parameter
    tuples for each template.

    If sampling is "bucket", parameter tuples are produced as follows:
        For each template parameter name:
            For each possible value of that parameter:
                Try random combinations of values for all the other
                template parameters and keep the first combination
                that the prefilter accepts (none if there is no such
                combination).

    If seed is not None, the global random module is seeded with it first.
    Template files are processed in sorted path order so that a given seed
    always produces the same test cases in the same order.

    Raises ValueError for an unsupported sampling value or a malformed
    template file, and NameError for an unknown prefilter name.
    """
    if seed is not None:
        random.seed(seed)

    # Validate the sampling specification once, up front, instead of once per
    # template (and even when no template is found).
    sampling_method = _build_sampling_method(sampling)

    parameter_tuples = []
    current_module_dir = pathlib.Path(__file__).parent.absolute()
    # os.walk yields entries in an arbitrary, filesystem-dependent order;
    # sorting keeps the random draws and the resulting test order reproducible.
    template_files = sorted(
        os.path.join(root, f)
        for root, _, files in os.walk(current_module_dir, followlinks=True)
        for f in files
        if f.endswith(".template.mlir")
    )
    for template_file in template_files:
        # Parse the template file
        with open(template_file, "r") as f:
            template_file_contents = f.read()
        # Split on the first marker only, so the MLIR part may itself contain
        # the marker text.
        json_string, marker, mlir_source = template_file_contents.partition(
            "### START TEST ###"
        )
        if not marker:
            raise ValueError(
                f'{template_file} does not contain a valid test specification as '
                'it does not contain the line "### START TEST ###".'
            )
        test_spec: dict = json.loads(json_string)
        mlir_template = jinja2.Template(mlir_source, undefined=jinja2.StrictUndefined)
        if "parameters" not in test_spec:
            raise ValueError(
                f'{template_file} does not contain a valid test specification as '
                'it does not specify a value for the key "parameters".'
            )
        elif "run" not in test_spec or not isinstance(test_spec["run"], str):
            raise ValueError(
                f'{template_file} does not contain a valid test specification as '
                'it does not specify a valid value for the key "run".'
            )
        prefilter_name = test_spec.get("prefilter")
        parameter_dict_filter = PREFILTER_FUNCTIONS.get(prefilter_name)
        if parameter_dict_filter is None:
            raise NameError(f"Unknown prefilter function named {repr(prefilter_name)}.")

        # Grab test running command
        test_execution_command_template = test_spec["run"]

        # Grab parameter choices
        parameter_choices: Dict[str, List[str]] = dict()
        for parameter_name, parameter_value_choices in test_spec["parameters"].items():
            if (
                isinstance(parameter_value_choices, str)
                and parameter_value_choices in NAMED_PARAMETER_VALUE_CHOICES
            ):
                parameter_choices[parameter_name] = NAMED_PARAMETER_VALUE_CHOICES[
                    parameter_value_choices
                ]
            elif isinstance(parameter_value_choices, list):
                parameter_choices[parameter_name] = parameter_value_choices
            else:
                raise ValueError(
                    f"{repr(parameter_value_choices)} does not specify a valid set of parameter values."
                )

        # Handle each sampling case separately
        if sampling_method is None:
            # "bucket" sampling: one test case per (parameter name, value) pair
            for parameter_name, parameter_possible_values in parameter_choices.items():
                other_parameter_names = [
                    name for name in parameter_choices if name != parameter_name
                ]
                for parameter_value in parameter_possible_values:
                    # Randomly ordered values for all the other parameters.
                    # NOTE: itertools.product consumes its input iterables
                    # completely before yielding, so every shuffler is run to
                    # exhaustion here.
                    shuffled_other_values = [
                        lazy_list_shuffler(parameter_choices[name])
                        for name in other_parameter_names
                    ]
                    other_parameter_value_tuples = itertools.product(
                        *shuffled_other_values
                    )
                    # Go through possible parameter dicts until we find a valid one
                    for other_parameter_values in other_parameter_value_tuples:
                        parameter_dict = dict(
                            zip(other_parameter_names, other_parameter_values)
                        )
                        parameter_dict[parameter_name] = parameter_value
                        if parameter_dict_filter(parameter_dict):
                            parameter_tuples.append(
                                _make_parameter_tuple(
                                    template_file,
                                    mlir_template,
                                    test_execution_command_template,
                                    parameter_dict,
                                )
                            )
                            break
        else:
            # Grab all possible parameter dicts
            parameter_names = list(parameter_choices)
            all_parameter_dicts = (
                dict(zip(parameter_names, parameter_values))
                for parameter_values in itertools.product(*parameter_choices.values())
            )

            # Find the requested parameter dicts
            accepted_parameter_dicts = filter(parameter_dict_filter, all_parameter_dicts)
            selected_parameter_dicts = sampling_method(accepted_parameter_dicts)

            # Append one parameter tuple for each test case
            for parameter_dict in selected_parameter_dicts:
                parameter_tuples.append(
                    _make_parameter_tuple(
                        template_file,
                        mlir_template,
                        test_execution_command_template,
                        parameter_dict,
                    )
                )

    return parameter_tuples
291+
292+
293+
def generate_test_id_string(template_file: str, parameter_dict: Dict[str, str]) -> str:
    """
    Build the id string of one generated test case.

    The id is the template file path reduced to its alphanumeric
    characters, followed by one "(name:value)" group per parameter in
    sorted name order, with names and values passed through re.escape.
    """
    file_part = "".join(filter(str.isalnum, template_file))
    parameter_parts = [
        f"({re.escape(name)}:{re.escape(parameter_dict[name])})"
        for name in sorted(parameter_dict)
    ]
    return file_part + "".join(parameter_parts)
10298

11-
# Ensure graphblas-opt is built
12-
subprocess.run(["python", os.path.join("mlir_graphblas", "src", "build.py")])
13299

300+
def parameterize_templatized_filecheck_test(metafunc: "_pytest.python.Metafunc"):
    """
    Parametrize a test function with the test cases generated from the
    FileCheck template files.

    The value of the --filecheck-sampling option selects the sampling
    method: a string of digits is used as an int, a decimal number as a
    float, and anything else is passed on as a string.

    The test function receives the arguments "mlir_code",
    "test_command_template", "template_file" and "parameter_dict". If no
    test case is generated, it is parametrized with an empty set of values.
    """
    sampling_method_string = metafunc.config.getoption("--filecheck-sampling")
    if sampling_method_string.isdigit():
        sampling = int(sampling_method_string)
    elif re.match(r"^\d+(\.\d+)?$", sampling_method_string):
        sampling = float(sampling_method_string)
    else:
        sampling = sampling_method_string

    # When config has "workerinput", the seed is derived from its
    # "testrunuid" entry; otherwise the current time is used.
    # NOTE(review): presumably testrunuid is identical for all workers of one
    # run, so they all sample the same test cases; confirm against xdist.
    # TODO update the seed
    seed = (
        zlib.adler32(metafunc.config.workerinput["testrunuid"].encode())
        if hasattr(metafunc.config, "workerinput")
        else time.time()
    )

    parameter_tuples = parameter_tuples_from_templates(sampling, seed)
    if parameter_tuples:
        ids, parameter_values = zip(*parameter_tuples)
    else:
        # zip(*[]) yields nothing to unpack, so spell out the empty case.
        ids, parameter_values = (), ()

    metafunc.parametrize(
        ["mlir_code", "test_command_template", "template_file", "parameter_dict"],
        parameter_values,
        ids=ids,
    )
    return

dev-environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ dependencies:
1010
- coverage
1111
- pytest
1212
- pytest-cov
13+
- pytest-xdist
1314
- black
1415

1516
# documentation

mlir_graphblas/functions.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ def get_mlir(self, *, make_private=True):
7979
{{ body }}
8080
8181
}
82-
"""
82+
""",
83+
undefined=jinja2.StrictUndefined,
8384
)
8485

8586
def get_mlir_module(self, make_private=False):
@@ -169,7 +170,8 @@ def get_mlir(self, *, make_private=True):
169170
170171
{% endif %}
171172
}
172-
"""
173+
""",
174+
undefined=jinja2.StrictUndefined,
173175
)
174176

175177

@@ -208,7 +210,8 @@ def get_mlir(self, *, make_private=True):
208210
%output = graphblas.matrix_select %input { selectors = ["{{ selector }}"] } : tensor<?x?xf64, #CSR64> to tensor<?x?xf64, #CSR64>
209211
return %output : tensor<?x?xf64, #CSR64>
210212
}
211-
"""
213+
""",
214+
undefined=jinja2.StrictUndefined,
212215
)
213216

214217

@@ -250,7 +253,8 @@ def get_mlir(self, *, make_private=True):
250253
251254
return %total : f64
252255
}
253-
"""
256+
""",
257+
undefined=jinja2.StrictUndefined,
254258
)
255259

256260

@@ -288,7 +292,8 @@ def get_mlir(self, *, make_private=True):
288292
289293
return %output : tensor<?x?xf64, #CSR64>
290294
}
291-
"""
295+
""",
296+
undefined=jinja2.StrictUndefined,
292297
)
293298

294299

@@ -348,5 +353,6 @@ def get_mlir(self, *, make_private=True):
348353
349354
return %output : tensor<?x?xf64, #CSR64>
350355
}
351-
"""
356+
""",
357+
undefined=jinja2.StrictUndefined,
352358
)

0 commit comments

Comments
 (0)