Skip to content

Commit 1e6d22b

Browse files
add attributes test
1 parent c6126b4 commit 1e6d22b

File tree

1 file changed

+47
-6
lines changed

1 file changed

+47
-6
lines changed

cuda_core/tests/test_memory.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -259,12 +259,6 @@ def test_mempool():
259259
dst_buffer.close()
260260
src_buffer.close()
261261

262-
# Test pool attributes
263-
used_mem = mr.used_mem_current
264-
assert used_mem >= 0
265-
reserved_mem = mr.reserved_mem_current
266-
assert reserved_mem >= 0
267-
268262
# Test error cases
269263
with pytest.raises(NotImplementedError, match="directly creating a Mempool object is not supported"):
270264
Mempool()
@@ -287,6 +281,53 @@ def test_mempool():
287281
buffer.close()
288282

289283

284+
@pytest.mark.parametrize(
285+
"property_name,expected_type",
286+
[
287+
("reuse_follow_event_dependencies", bool),
288+
("reuse_allow_opportunistic", bool),
289+
("reuse_allow_internal_dependencies", bool),
290+
("release_threshold", int),
291+
("reserved_mem_current", int),
292+
("reserved_mem_high", int),
293+
("used_mem_current", int),
294+
("used_mem_high", int),
295+
],
296+
)
297+
def test_mempool_properties(property_name, expected_type):
298+
"""Test all properties of the Mempool class."""
299+
device = Device()
300+
device.set_current()
301+
pool_size = 2097152 # 2MB size
302+
mr = Mempool.create(device.device_id, pool_size, enable_ipc=False)
303+
304+
# Get the property value
305+
value = getattr(mr, property_name)
306+
307+
# Test type
308+
assert isinstance(value, expected_type), f"{property_name} should return {expected_type}, got {type(value)}"
309+
310+
# Test value constraints
311+
if expected_type is int:
312+
assert value >= 0, f"{property_name} should be non-negative"
313+
314+
# Test memory usage properties with actual allocations
315+
if property_name in ["reserved_mem_current", "used_mem_current"]:
316+
# Allocate some memory and check if values increase
317+
initial_value = value
318+
buffer = mr.allocate(1024)
319+
new_value = getattr(mr, property_name)
320+
assert new_value >= initial_value, f"{property_name} should increase or stay same after allocation"
321+
buffer.close()
322+
323+
# Test high watermark properties
324+
if property_name in ["reserved_mem_high", "used_mem_high"]:
325+
# High watermark should never be less than current
326+
current_prop = property_name.replace("_high", "_current")
327+
current_value = getattr(mr, current_prop)
328+
assert value >= current_value, f"{property_name} should be >= {current_prop}"
329+
330+
290331
def mempool_child_process(importer, queue):
291332
try:
292333
device = Device()

0 commit comments

Comments
 (0)