|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | + |
| 5 | +import itertools |
| 6 | +from enum import Enum |
| 7 | + |
| 8 | +import numpy as np |
| 9 | + |
| 10 | + |
| 11 | +class NamedParam: |
| 12 | + def __init__(self, name, value): |
| 13 | + self.name = name |
| 14 | + self.value = value |
| 15 | + |
| 16 | + def __bool__(self): |
| 17 | + return bool(self.value) |
| 18 | + |
| 19 | + def pretty_name(self): |
| 20 | + if isinstance(self.value, Enum): |
| 21 | + value_str = self.value.name |
| 22 | + else: |
| 23 | + value_str = str(self.value) |
| 24 | + return f"{self.name}.{value_str}" |
| 25 | + |
| 26 | + |
| 27 | +class DenseOrder(Enum): |
| 28 | + """ |
| 29 | + Whether to initialize the dense layout in C or F order. |
| 30 | + For C, the strides can be explicit or implicit (None). |
| 31 | + """ |
| 32 | + |
| 33 | + C = "C" |
| 34 | + IMPLICIT_C = "implicit_c" |
| 35 | + F = "F" |
| 36 | + |
| 37 | + |
| 38 | +class _S: |
| 39 | + """ |
| 40 | + SliceSpec |
| 41 | + """ |
| 42 | + |
| 43 | + def __init__(self): |
| 44 | + self.slices = [] |
| 45 | + |
| 46 | + def __getitem__(self, value): |
| 47 | + self.slices.append(value) |
| 48 | + return self |
| 49 | + |
| 50 | + |
| 51 | +class LayoutSpec: |
| 52 | + """ |
| 53 | + Pretty printable specification of a layout in a test case. |
| 54 | + """ |
| 55 | + |
| 56 | + def __init__( |
| 57 | + self, |
| 58 | + shape, |
| 59 | + itemsize, |
| 60 | + stride_order=DenseOrder.C, |
| 61 | + perm=None, |
| 62 | + slices=None, |
| 63 | + np_ref=None, |
| 64 | + ): |
| 65 | + self.shape = shape |
| 66 | + self.itemsize = itemsize |
| 67 | + self.stride_order = stride_order |
| 68 | + self.perm = perm |
| 69 | + if slices is not None: |
| 70 | + assert isinstance(slices, _S) |
| 71 | + slices = slices.slices |
| 72 | + self.slices = slices |
| 73 | + self.np_ref = np_ref |
| 74 | + |
| 75 | + def pretty_name(self): |
| 76 | + desc = [ |
| 77 | + f"ndim.{len(self.shape)}", |
| 78 | + f"shape.{self.shape}", |
| 79 | + f"itemsize.{self.itemsize}", |
| 80 | + ] |
| 81 | + if self.stride_order is not None: |
| 82 | + if isinstance(self.stride_order, DenseOrder): |
| 83 | + desc.append(f"stride_order.{self.stride_order.value}") |
| 84 | + else: |
| 85 | + assert isinstance(self.stride_order, tuple) |
| 86 | + assert len(self.stride_order) == len(self.shape) |
| 87 | + desc.append(f"stride_order.{self.stride_order}") |
| 88 | + if self.perm is not None: |
| 89 | + desc.append(f"perm.{self.perm}") |
| 90 | + if self.slices is not None: |
| 91 | + desc.append(f"slices.{self.slices}") |
| 92 | + return "-".join(desc) |
| 93 | + |
| 94 | + def dtype_from_itemsize(self): |
| 95 | + return dtype_from_itemsize(self.itemsize) |
| 96 | + |
| 97 | + def np_order(self): |
| 98 | + return "F" if self.stride_order == DenseOrder.F else "C" |
| 99 | + |
| 100 | + def has_no_strides(self): |
| 101 | + return self.stride_order == DenseOrder.IMPLICIT_C |
| 102 | + |
| 103 | + def has_no_strides_transformed(self): |
| 104 | + return self.stride_order == DenseOrder.IMPLICIT_C and self.perm is None and self.slices is None |
| 105 | + |
| 106 | + |
| 107 | +def dtype_from_itemsize(itemsize): |
| 108 | + if itemsize <= 8: |
| 109 | + return np.dtype(f"int{itemsize * 8}") |
| 110 | + elif itemsize == 16: |
| 111 | + return np.dtype("complex128") |
| 112 | + else: |
| 113 | + raise ValueError(f"Unsupported itemsize: {itemsize}") |
| 114 | + |
| 115 | + |
| 116 | +def pretty_name(val): |
| 117 | + """ |
| 118 | + Pytest does not pretty print (repr/str) parameters of custom types. |
| 119 | + Use this function as the `ids` argument of `pytest.mark.parametrize`, e.g.: |
| 120 | + ``@pytest.mark.parametrize(..., ids=pretty_name)`` |
| 121 | + """ |
| 122 | + if hasattr(val, "pretty_name"): |
| 123 | + return val.pretty_name() |
| 124 | + # use default pytest pretty printing |
| 125 | + return None |
| 126 | + |
| 127 | + |
| 128 | +def flatten_mask2str(mask, ndim): |
| 129 | + return "".join("1" if mask & (1 << i) else "0" for i in range(ndim)) |
| 130 | + |
| 131 | + |
| 132 | +def random_permutations(rng, perm_len, cutoff_len=3, sample_size=6): |
| 133 | + if perm_len <= cutoff_len: |
| 134 | + return [perm for perm in itertools.permutations(range(perm_len))] |
| 135 | + perms = [] |
| 136 | + for _ in range(sample_size): |
| 137 | + perm = list(range(perm_len)) |
| 138 | + rng.shuffle(perm) |
| 139 | + perms.append(tuple(perm)) |
| 140 | + return perms |
| 141 | + |
| 142 | + |
| 143 | +def inv_permutation(perm): |
| 144 | + inv = [None] * len(perm) |
| 145 | + for i, p in enumerate(perm): |
| 146 | + inv[p] = i |
| 147 | + return tuple(inv) |
| 148 | + |
| 149 | + |
| 150 | +def permuted(t, perm): |
| 151 | + return tuple(t[i] for i in perm) |
0 commit comments