Skip to content

Commit edb8fa9

Browse files
add proper cleanup to test files
1 parent 1e6d22b commit edb8fa9

File tree

1 file changed

+111
-79
lines changed

1 file changed

+111
-79
lines changed

cuda_core/tests/test_memory.py

Lines changed: 111 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -301,31 +301,41 @@ def test_mempool_properties(property_name, expected_type):
301301
pool_size = 2097152 # 2MB size
302302
mr = Mempool.create(device.device_id, pool_size, enable_ipc=False)
303303

304-
# Get the property value
305-
value = getattr(mr, property_name)
306-
307-
# Test type
308-
assert isinstance(value, expected_type), f"{property_name} should return {expected_type}, got {type(value)}"
309-
310-
# Test value constraints
311-
if expected_type is int:
312-
assert value >= 0, f"{property_name} should be non-negative"
313-
314-
# Test memory usage properties with actual allocations
315-
if property_name in ["reserved_mem_current", "used_mem_current"]:
316-
# Allocate some memory and check if values increase
317-
initial_value = value
318-
buffer = mr.allocate(1024)
319-
new_value = getattr(mr, property_name)
320-
assert new_value >= initial_value, f"{property_name} should increase or stay same after allocation"
321-
buffer.close()
322-
323-
# Test high watermark properties
324-
if property_name in ["reserved_mem_high", "used_mem_high"]:
325-
# High watermark should never be less than current
326-
current_prop = property_name.replace("_high", "_current")
327-
current_value = getattr(mr, current_prop)
328-
assert value >= current_value, f"{property_name} should be >= {current_prop}"
304+
try:
305+
# Get the property value
306+
value = getattr(mr, property_name)
307+
308+
# Test type
309+
assert isinstance(value, expected_type), f"{property_name} should return {expected_type}, got {type(value)}"
310+
311+
# Test value constraints
312+
if expected_type is int:
313+
assert value >= 0, f"{property_name} should be non-negative"
314+
315+
# Test memory usage properties with actual allocations
316+
if property_name in ["reserved_mem_current", "used_mem_current"]:
317+
# Allocate some memory and check if values increase
318+
initial_value = value
319+
buffer = None
320+
try:
321+
buffer = mr.allocate(1024)
322+
new_value = getattr(mr, property_name)
323+
assert new_value >= initial_value, f"{property_name} should increase or stay same after allocation"
324+
finally:
325+
if buffer is not None:
326+
buffer.close()
327+
328+
# Test high watermark properties
329+
if property_name in ["reserved_mem_high", "used_mem_high"]:
330+
# High watermark should never be less than current
331+
current_prop = property_name.replace("_high", "_current")
332+
current_value = getattr(mr, current_prop)
333+
assert value >= current_value, f"{property_name} should be >= {current_prop}"
334+
335+
finally:
336+
# Ensure we allocate and deallocate a small buffer to flush any pending operations
337+
flush_buffer = mr.allocate(64)
338+
flush_buffer.close()
329339

330340

331341
def mempool_child_process(importer, queue):
@@ -387,62 +397,84 @@ def test_ipc_mempool():
387397
stream = device.create_stream()
388398
pool_size = 2097152 # 2MB size
389399
mr = Mempool.create(device.device_id, pool_size, enable_ipc=True)
390-
shareable_handle = mr.get_shareable_handle()
391-
392-
# Allocate and export memory
393-
buffer = mr.allocate(64)
394-
395-
# Fill buffer with test pattern using unified memory
396-
unified_mr = DummyUnifiedMemoryResource(device)
397-
src_buffer = unified_mr.allocate(64)
398-
src_ptr = ctypes.cast(int(src_buffer.handle), ctypes.POINTER(ctypes.c_byte))
399-
for i in range(64):
400-
src_ptr[i] = ctypes.c_byte(i)
401-
402-
buffer.copy_from(src_buffer, stream=stream)
403-
device.sync()
404-
src_buffer.close()
405-
406-
# Export buffer for IPC
407-
ipc_buffer = mr.export_buffer(buffer)
408400

409401
# Create socket pair for handle transfer
410402
exporter, importer = socketpair(AF_UNIX, SOCK_DGRAM)
411-
412-
# Start child process
413-
multiprocessing.set_start_method("spawn", force=True)
414403
queue = multiprocessing.Queue()
415-
process = multiprocessing.Process(target=mempool_child_process, args=(importer, queue))
416-
process.start()
417-
418-
# Send handles to child process
419-
exporter.sendmsg([], [(SOL_SOCKET, SCM_RIGHTS, array.array("i", [shareable_handle]))])
420-
queue.put(ipc_buffer)
421-
422-
# Wait for child process
423-
process.join(timeout=10)
424-
assert process.exitcode == 0
425-
426-
# Check for exceptions
427-
if not queue.empty():
428-
result = queue.get()
429-
if isinstance(result, tuple):
430-
exception, traceback_str = result
431-
print("\nException in child process:")
432-
print(traceback_str)
433-
raise exception
434-
assert result is True
435-
436-
# Verify child process wrote the inverted pattern using unified memory
437-
verify_buffer = unified_mr.allocate(64)
438-
verify_buffer.copy_from(buffer, stream=stream)
439-
device.sync()
404+
process = None
440405

441-
verify_ptr = ctypes.cast(int(verify_buffer.handle), ctypes.POINTER(ctypes.c_byte))
442-
for i in range(64):
443-
assert (
444-
ctypes.c_byte(verify_ptr[i]).value == ctypes.c_byte(255 - i).value
445-
), f"Child process data not reflected in parent at index {i}"
446-
447-
verify_buffer.close()
448-
buffer.close()
406+
try:
407+
shareable_handle = mr.get_shareable_handle()
408+
409+
# Allocate and export memory
410+
buffer = mr.allocate(64)
411+
412+
try:
413+
# Fill buffer with test pattern using unified memory
414+
unified_mr = DummyUnifiedMemoryResource(device)
415+
src_buffer = unified_mr.allocate(64)
416+
try:
417+
src_ptr = ctypes.cast(int(src_buffer.handle), ctypes.POINTER(ctypes.c_byte))
418+
for i in range(64):
419+
src_ptr[i] = ctypes.c_byte(i)
420+
421+
buffer.copy_from(src_buffer, stream=stream)
422+
device.sync()
423+
finally:
424+
src_buffer.close()
425+
426+
# Export buffer for IPC
427+
ipc_buffer = mr.export_buffer(buffer)
428+
429+
# Start child process
430+
multiprocessing.set_start_method("spawn", force=True)
431+
process = multiprocessing.Process(target=mempool_child_process, args=(importer, queue))
432+
process.start()
433+
434+
# Send handles to child process
435+
exporter.sendmsg([], [(SOL_SOCKET, SCM_RIGHTS, array.array("i", [shareable_handle]))])
436+
queue.put(ipc_buffer)
437+
438+
# Wait for child process
439+
process.join(timeout=10)
440+
assert process.exitcode == 0
441+
442+
# Check for exceptions
443+
if not queue.empty():
444+
result = queue.get()
445+
if isinstance(result, tuple):
446+
exception, traceback_str = result
447+
print("\nException in child process:")
448+
print(traceback_str)
449+
raise exception
450+
assert result is True
451+
452+
# Verify child process wrote the inverted pattern using unified memory
453+
verify_buffer = unified_mr.allocate(64)
454+
try:
455+
verify_buffer.copy_from(buffer, stream=stream)
456+
device.sync()
457+
458+
verify_ptr = ctypes.cast(int(verify_buffer.handle), ctypes.POINTER(ctypes.c_byte))
459+
for i in range(64):
460+
assert (
461+
ctypes.c_byte(verify_ptr[i]).value == ctypes.c_byte(255 - i).value
462+
), f"Child process data not reflected in parent at index {i}"
463+
finally:
464+
verify_buffer.close()
465+
466+
finally:
467+
buffer.close()
468+
469+
finally:
470+
# Clean up all resources
471+
if process is not None and process.is_alive():
472+
process.terminate()
473+
process.join(timeout=1)
474+
queue.close()
475+
queue.join_thread() # Ensure the queue's background thread is cleaned up
476+
exporter.close()
477+
importer.close()
478+
# Flush any pending operations
479+
flush_buffer = mr.allocate(64)
480+
flush_buffer.close()

0 commit comments

Comments
 (0)