|
import distutils.core
import itertools
import json
import os
import pathlib
import random
import re
import subprocess
import sys
import time
import zlib

import jinja2

from typing import Any, Dict, List, Tuple, Generator, Callable, Literal, Union, Optional
5 | 14 |
|
##########
# Config #
##########


def pytest_configure(config: "_pytest.config.Config"):
    """Build the C extensions and the graphblas-opt tool before tests run.

    Under pytest-xdist, worker processes carry a ``workerinput`` attribute on
    their config; only the main process (which lacks it) performs the build so
    the workers don't all rebuild concurrently.
    """
    if not hasattr(config, "workerinput"):  # only run once on main process
        distutils.core.run_setup(
            "./setup.py", script_args=["build_ext", "--inplace"], stop_after="run"
        )
        # Ensure graphblas-opt is built.  check=True makes a failed build abort
        # the session immediately instead of letting every test fail later with
        # confusing errors; sys.executable guarantees we build with the same
        # interpreter that is running pytest.
        subprocess.run(
            [sys.executable, os.path.join("mlir_graphblas", "src", "build.py")],
            check=True,
        )

    return
| 29 | + |
| 30 | + |
def pytest_addoption(parser: "_pytest.config.argparsing.Parser"):
    """Register the --filecheck-sampling command line option with pytest."""
    option_help = "Method to sample the space of the templatized FileCheck tests."
    parser.addoption(
        "--filecheck-sampling",
        action="store",
        default="default",
        help=option_help,
    )
    return
| 39 | + |
| 40 | + |
def pytest_generate_tests(metafunc: "_pytest.python.Metafunc"):
    """Hook: expand the templatized FileCheck test into concrete test cases."""
    if metafunc.function.__name__ != "test_filecheck_mlir":
        return
    parameterize_templatized_filecheck_test(metafunc)
    return
| 45 | + |
| 46 | + |
| 47 | +########################################### |
| 48 | +# Parameterization of test_filecheck_mlir # |
| 49 | +########################################### |
| 50 | + |
# Reusable, named sets of template-parameter values.  A template's JSON spec
# may reference one of these names (as a string) instead of listing the
# choices inline; the lookup happens in parameter_tuples_from_templates.
NAMED_PARAMETER_VALUE_CHOICES = {
    # Element types (integer, float, bfloat widths) the templates may expand to.
    "STANDARD_ELEMENT_TYPES": [
        "i1",
        "i4",
        "i8",
        "i16",
        "i32",
        "i64",
        "f16",
        "bf16",
        "f32",
        "f64",
        "f80",
        "f128",
    ],
    "MATRIX_APPLY_OPERATORS": ["min"],
    # Per-dimension sparsity encodings.
    "SPARSITY_TYPES": ["dense", "compressed", "singleton"],
    "MATRIX_MULTIPLY_SEMIRINGS": ["plus_times", "plus_pair", "plus_plus"],
}
| 70 | + |
# Registry of prefilter functions, keyed by function name.  The None key maps
# to a filter that accepts everything (used when a template names no prefilter).
PREFILTER_FUNCTIONS = {None: lambda *args: True}

# Prefilter Functions


def prefilter_func(
    func: Callable[[Dict[str, str]], bool]
) -> Callable[[Dict[str, str]], bool]:
    """Decorator: register ``func`` in PREFILTER_FUNCTIONS under its own name.

    Raises if the name is already taken, then returns ``func`` unchanged.
    """
    name = func.__name__
    if name in PREFILTER_FUNCTIONS:
        raise Exception(f"{name!r} is already in use as prefilter function name.")
    PREFILTER_FUNCTIONS[name] = func
    return func
| 85 | + |
| 86 | + |
@prefilter_func
def different_thunk_and_element_type(parameter_dict: Dict[str, str]) -> bool:
    """Keep only parameter combinations whose element and thunk types differ."""
    element_type = parameter_dict["element_type"]
    thunk_type = parameter_dict["thunk_type"]
    return element_type != thunk_type
| 90 | + |
| 91 | + |
@prefilter_func
def sparse_but_not_compressed_sparse(parameter_dict: Dict[str, str]) -> bool:
    """Keep combos that use "compressed" in some dimension, except the plain
    ("dense", "compressed") layout."""
    layouts = (parameter_dict["sparsity0"], parameter_dict["sparsity1"])
    if "compressed" not in layouts:
        return False
    return layouts != ("dense", "compressed")
| 99 | + |
| 100 | + |
| 101 | +# Template Expansion |
| 102 | + |
| 103 | + |
def lazy_list_shuffler(input_list: list) -> Generator[Any, None, None]:
    """
    Lazily yield the items of ``input_list`` in uniformly random order.

    Implemented as an incremental Fisher-Yates shuffle over a private copy
    of the list, so producing each item is O(1) and exhausting the generator
    is O(n).  (The previous rejection-sampling implementation re-rolled
    indices until it hit an unseen one, which made exhaustion O(n**2) and
    each late draw arbitrarily slow.)  The input list is never mutated, and
    laziness is preserved: stopping early costs only the items drawn.
    """
    pool = list(input_list)  # private copy; we shuffle in place below
    for end in range(len(pool) - 1, -1, -1):
        picked = random.randint(0, end)
        # Swap the chosen item into the (shrinking) tail and yield it.
        pool[picked], pool[end] = pool[end], pool[picked]
        yield pool[end]
    return
| 122 | + |
| 123 | + |
def parameter_tuples_from_templates(
    sampling: Union[int, float, Literal["exhaustive", "bucket", "default"]],
    seed: Optional[int] = None,
) -> List[Tuple[str, Tuple[str, str, str, Dict[str, str]]]]:
    """
    Expand every ``*.template.mlir`` file under this directory into concrete
    test cases.

    Returns a list of tuples of the form (
        "test id string",
        (
            "mlir_code",
            "test command jinja template",
            "/template/file/location",
            {"template_parameter_0": "template_parameter_value_0", ...},
        )
    )

    If sampling is a float (in the range [0, 1]), the number of parameter
    tuples returned will be (total_num_possible_cases * sampling) for each
    template.

    If sampling is an int, the number of parameter tuples returned will
    be (sampling) for each template.

    If sampling is "exhaustive", we will return all possible parameter
    tuples for each template.

    If sampling is "default", exactly one parameter tuple (the first one
    that survives the prefilter) is returned for each template.

    If sampling is "bucket", we will yield parameter tuples as follows:
        For each template parameter name:
            Randomly sample one template parameter value for all other
            template parameter names
            Yield the pytest param for the set of template parameter
    """
    # Seed once up front so template expansion is reproducible for a given
    # seed (callers derive the same seed on every pytest-xdist worker).
    if seed is not None:
        random.seed(seed)
    parameter_tuples = []
    current_module_dir = pathlib.Path(__file__).parent.absolute()
    # Recursively discover all template files at or below this module.
    template_files = (
        os.path.join(root, f)
        for root, _, files in os.walk(current_module_dir, followlinks=True)
        for f in files
        if f.endswith(".template.mlir")
    )
    for template_file in template_files:
        # Parse the template file: a JSON test spec, the literal marker line,
        # then the jinja2 MLIR template body.
        with open(template_file, "r") as f:
            json_sting, mlir_template = f.read().split("### START TEST ###")
        test_spec: dict = json.loads(json_sting)
        # StrictUndefined makes rendering fail loudly if the spec omits a
        # parameter that the template body references.
        mlir_template = jinja2.Template(mlir_template, undefined=jinja2.StrictUndefined)
        if "parameters" not in test_spec:
            raise ValueError(
                f"{template_file} does not contain a valid test specification as "
                'it does not specify a value for the key "parameters".'
            )
        elif "run" not in test_spec or not isinstance(test_spec["run"], str):
            raise ValueError(
                f"{template_file} does not contain a valid test specification as "
                'it does not specify a valid value for the key "run".'
            )
        # A missing "prefilter" key yields None, which maps to the
        # accept-everything filter registered under the None key.
        prefilter_name = test_spec.get("prefilter")
        parameter_dict_filter = PREFILTER_FUNCTIONS.get(prefilter_name)
        if parameter_dict_filter is None:
            raise NameError(f"Unknown prefilter function named {repr(prefilter_name)}.")

        # Grab test running command
        test_execution_command_template = test_spec["run"]

        # Grab parameter choices: a string value names a shared choice set in
        # NAMED_PARAMETER_VALUE_CHOICES; a list is used verbatim.
        parameter_choices: Dict[str, List[str]] = dict()
        for parameter_name, parameter_value_choices in test_spec["parameters"].items():
            if (
                isinstance(parameter_value_choices, str)
                and parameter_value_choices in NAMED_PARAMETER_VALUE_CHOICES
            ):
                parameter_choices[parameter_name] = NAMED_PARAMETER_VALUE_CHOICES[
                    parameter_value_choices
                ]
            elif isinstance(parameter_value_choices, list):
                parameter_choices[parameter_name] = parameter_value_choices
            else:
                raise ValueError(
                    f"{repr(parameter_value_choices)} does not specify a valid set of parameter values."
                )

        # Handle each sampling case separately
        if sampling == "bucket":
            parameter_dicts = []
            # Coverage-style sampling: every value of every parameter appears
            # in at least one generated case, with the remaining parameters
            # drawn at random until the prefilter accepts the combination.
            for parameter_name, parameter_possible_values in parameter_choices.items():
                for parameter_value in parameter_possible_values:
                    # Set up lazy iterators to randomly grab the values of all the other parameters
                    other_parameter_names = [
                        name
                        for name in parameter_choices.keys()
                        if name != parameter_name
                    ]
                    other_parameter_choices_values = (
                        parameter_choices[name] for name in other_parameter_names
                    )
                    other_parameter_choices_values = map(
                        lazy_list_shuffler, other_parameter_choices_values
                    )
                    other_parameter_value_tuples = itertools.product(
                        *other_parameter_choices_values
                    )
                    other_parameter_dicts = (
                        dict(zip(other_parameter_names, other_parameter_values))
                        for other_parameter_values in other_parameter_value_tuples
                    )
                    # Go through possible parameter dicts until we find a valid one
                    # (if none passes the prefilter, this parameter value is
                    # silently skipped).
                    for parameter_dict in other_parameter_dicts:
                        parameter_dict[parameter_name] = parameter_value
                        if parameter_dict_filter(parameter_dict):
                            parameter_tuples.append(
                                (
                                    generate_test_id_string(
                                        template_file, parameter_dict
                                    ),
                                    (
                                        mlir_template.render(**parameter_dict),
                                        test_execution_command_template,
                                        template_file,
                                        parameter_dict,
                                    ),
                                )
                            )
                            break
        else:
            # The remaining modes all reduce to "select a subset of the
            # prefiltered cartesian product of parameter values"; each mode
            # supplies its own sampling_method closure.
            if isinstance(sampling, int):

                def sampling_method(parameter_dicts):
                    # Fixed-size random sample, capped at the population size.
                    parameter_dicts = list(parameter_dicts)
                    num_samples = min(sampling, len(parameter_dicts))
                    return random.sample(parameter_dicts, num_samples)

            elif isinstance(sampling, float):
                if sampling < 0 or sampling > 1:
                    raise ValueError(
                        f"Portion of parameter dicts to sample must be "
                        f"in the range [0, 1], got {sampling}."
                    )

                def sampling_method(parameter_dicts):
                    # Proportional random sample.
                    parameter_dicts = list(parameter_dicts)
                    num_samples = int(len(parameter_dicts) * sampling)
                    return random.sample(parameter_dicts, num_samples)

            elif sampling == "default":

                def sampling_method(parameter_dicts):
                    # Take just the first valid parameter dict, if any exists.
                    for parameter_dict in parameter_dicts:
                        return [parameter_dict]
                    return []

            elif sampling == "exhaustive":

                def sampling_method(parameter_dicts):
                    return parameter_dicts

            else:
                raise ValueError(
                    f"{repr(sampling)} is not a supported sampling method."
                )

            # Grab all possible parameter dicts
            parameter_names = parameter_choices.keys()
            parameter_value_tuples = itertools.product(*parameter_choices.values())
            all_parameter_dicts = (
                dict(zip(parameter_names, parameter_values))
                for parameter_values in parameter_value_tuples
            )

            # Find the requested parameter dicts
            parameter_dicts = filter(parameter_dict_filter, all_parameter_dicts)
            parameter_dicts = sampling_method(parameter_dicts)

            # Append one parameter dict for each test case
            for parameter_dict in parameter_dicts:
                parameter_tuples.append(
                    (
                        generate_test_id_string(template_file, parameter_dict),
                        (
                            mlir_template.render(**parameter_dict),
                            test_execution_command_template,
                            template_file,
                            parameter_dict,
                        ),
                    )
                )

    return parameter_tuples
| 309 | + |
| 310 | + |
def generate_test_id_string(template_file: str, parameter_dict: Dict[str, str]) -> str:
    """Build a deterministic pytest ID from a template path and its parameters.

    The path is reduced to its alphanumeric characters; each parameter is
    appended as "(key:value)" in sorted-key order.
    """
    file_part = "".join(filter(str.isalnum, template_file))
    parameter_parts = [
        f"({re.escape(name)}:{re.escape(parameter_dict[name])})"
        for name in sorted(parameter_dict)
    ]
    return file_part + "".join(parameter_parts)
| 316 | + |
| 317 | + |
def parameterize_templatized_filecheck_test(metafunc: "_pytest.python.Metafunc"):
    """Parametrize test_filecheck_mlir from the --filecheck-sampling option.

    The option string is coerced to an int (pure digits) or a float
    (digits with a decimal point); anything else is passed through as a
    named sampling mode for parameter_tuples_from_templates.
    """
    sampling_method_string = metafunc.config.getoption("--filecheck-sampling")
    if sampling_method_string.isdigit():
        sampling = int(sampling_method_string)
    elif re.match(r"^\d+(\.\d+)?$", sampling_method_string):  # raw string: \d is a regex escape
        sampling = float(sampling_method_string)
    else:
        sampling = sampling_method_string
    # Under pytest-xdist, derive the seed from the run UID so every worker
    # builds an identical parameter list; otherwise fall back to wall-clock
    # time (truncated to int to match the seed parameter's annotation).
    seed = (
        zlib.adler32(metafunc.config.workerinput["testrunuid"].encode())
        if hasattr(metafunc.config, "workerinput")
        else int(time.time())
    )
    parameter_tuples = parameter_tuples_from_templates(sampling, seed)
    ids: Tuple[str, ...]
    parameter_values: tuple
    if parameter_tuples:
        ids, parameter_values = zip(*parameter_tuples)
    else:
        # No templates found (or none survived sampling): parametrize with an
        # empty set so pytest reports the test as unrun instead of crashing
        # on zip(*[]) with "not enough values to unpack".
        ids, parameter_values = (), ()
    metafunc.parametrize(
        ["mlir_code", "test_command_template", "template_file", "parameter_dict"],
        parameter_values,
        ids=ids,
    )
10 | 337 |
|
11 | 338 | # Ensure graphblas-opt is built |
|
0 commit comments