Skip to content

Commit c238c1a

Browse files
committed
Add StridedLayout tests
Signed-off-by: Kamil Tokarski <ktokarski@nvidia.com>
1 parent fab878e commit c238c1a

File tree

2 files changed

+974
-0
lines changed

2 files changed

+974
-0
lines changed

cuda_core/tests/helpers/layout.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import itertools
6+
from enum import Enum
7+
8+
import numpy as np
9+
10+
11+
class NamedParam:
12+
def __init__(self, name, value):
13+
self.name = name
14+
self.value = value
15+
16+
def __bool__(self):
17+
return bool(self.value)
18+
19+
def pretty_name(self):
20+
if isinstance(self.value, Enum):
21+
value_str = self.value.name
22+
else:
23+
value_str = str(self.value)
24+
return f"{self.name}.{value_str}"
25+
26+
27+
class DenseOrder(Enum):
28+
"""
29+
Whether to initialize the dense layout in C or F order.
30+
For C, the strides can be explicit or implicit (None).
31+
"""
32+
33+
C = "C"
34+
IMPLICIT_C = "implicit_c"
35+
F = "F"
36+
37+
38+
class _S:
39+
"""
40+
SliceSpec
41+
"""
42+
43+
def __init__(self):
44+
self.slices = []
45+
46+
def __getitem__(self, value):
47+
self.slices.append(value)
48+
return self
49+
50+
51+
class LayoutSpec:
52+
"""
53+
Pretty printable specification of a layout in a test case.
54+
"""
55+
56+
def __init__(
57+
self,
58+
shape,
59+
itemsize,
60+
stride_order=DenseOrder.C,
61+
perm=None,
62+
slices=None,
63+
np_ref=None,
64+
):
65+
self.shape = shape
66+
self.itemsize = itemsize
67+
self.stride_order = stride_order
68+
self.perm = perm
69+
if slices is not None:
70+
assert isinstance(slices, _S)
71+
slices = slices.slices
72+
self.slices = slices
73+
self.np_ref = np_ref
74+
75+
def pretty_name(self):
76+
desc = [
77+
f"ndim.{len(self.shape)}",
78+
f"shape.{self.shape}",
79+
f"itemsize.{self.itemsize}",
80+
]
81+
if self.stride_order is not None:
82+
if isinstance(self.stride_order, DenseOrder):
83+
desc.append(f"stride_order.{self.stride_order.value}")
84+
else:
85+
assert isinstance(self.stride_order, tuple)
86+
assert len(self.stride_order) == len(self.shape)
87+
desc.append(f"stride_order.{self.stride_order}")
88+
if self.perm is not None:
89+
desc.append(f"perm.{self.perm}")
90+
if self.slices is not None:
91+
desc.append(f"slices.{self.slices}")
92+
return "-".join(desc)
93+
94+
def dtype_from_itemsize(self):
95+
return dtype_from_itemsize(self.itemsize)
96+
97+
def np_order(self):
98+
return "F" if self.stride_order == DenseOrder.F else "C"
99+
100+
def has_no_strides(self):
101+
return self.stride_order == DenseOrder.IMPLICIT_C
102+
103+
def has_no_strides_transformed(self):
104+
return self.stride_order == DenseOrder.IMPLICIT_C and self.perm is None and self.slices is None
105+
106+
107+
def dtype_from_itemsize(itemsize):
108+
if itemsize <= 8:
109+
return np.dtype(f"int{itemsize * 8}")
110+
elif itemsize == 16:
111+
return np.dtype("complex128")
112+
else:
113+
raise ValueError(f"Unsupported itemsize: {itemsize}")
114+
115+
116+
def pretty_name(val):
117+
"""
118+
Pytest does not pretty print (repr/str) parameters of custom types.
119+
Use this function as the `ids` argument of `pytest.mark.parametrize`, e.g.:
120+
``@pytest.mark.parametrize(..., ids=pretty_name)``
121+
"""
122+
if hasattr(val, "pretty_name"):
123+
return val.pretty_name()
124+
# use default pytest pretty printing
125+
return None
126+
127+
128+
def flatten_mask2str(mask, ndim):
129+
return "".join("1" if mask & (1 << i) else "0" for i in range(ndim))
130+
131+
132+
def random_permutations(rng, perm_len, cutoff_len=3, sample_size=6):
133+
if perm_len <= cutoff_len:
134+
return [perm for perm in itertools.permutations(range(perm_len))]
135+
perms = []
136+
for _ in range(sample_size):
137+
perm = list(range(perm_len))
138+
rng.shuffle(perm)
139+
perms.append(tuple(perm))
140+
return perms
141+
142+
143+
def inv_permutation(perm):
144+
inv = [None] * len(perm)
145+
for i, p in enumerate(perm):
146+
inv[p] = i
147+
return tuple(inv)
148+
149+
150+
def permuted(t, perm):
151+
return tuple(t[i] for i in perm)

0 commit comments

Comments
 (0)