Skip to content

Commit be59702

Browse files
committed
add resource management to mempool
1 parent 649b204 commit be59702

File tree

1 file changed

+68
-16
lines changed

1 file changed

+68
-16
lines changed

cuda_core/cuda/core/experimental/_memory.py

Lines changed: 68 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,8 @@ def __release_buffer__(self, buffer: memoryview, /):
213213
class IPCBuffer(Buffer):
214214
"""Buffer class to represent a buffer description which can be shared across processes.
215215
It is not a valid buffer containing data, but rather a description used by the importing
216-
process to construct a valid buffer."""
216+
process to construct a valid buffer. It's primary use is to provide a serialization
217+
mechanism for passing exported buffers between processes."""
217218

218219
def __init__(self, reserved: bytes, size):
219220
super().__init__(0, 0)
@@ -248,10 +249,36 @@ def _reconstruct(cls, reserved, size):
248249

249250

250251
class MemoryResource(abc.ABC):
251-
__slots__ = ("_handle",)
252+
"""Base class for memory resources.
253+
254+
This class provides an abstract interface for memory resources and includes
255+
an optional destruction mechanism through _MembersNeededForFinalize.
256+
"""
257+
258+
class _MembersNeededForFinalize:
259+
__slots__ = ("handle", "is_initialized")
260+
261+
def __init__(self, mr_obj, handle=None):
262+
self.handle = handle
263+
self.is_initialized = False
264+
weakref.finalize(mr_obj, self.close)
265+
266+
def close(self):
267+
if self.is_initialized and self.handle:
268+
# Specific cleanup can be implemented by derived classes
269+
self.handle = None
270+
self.is_initialized = False
271+
272+
__slots__ = ("__weakref__", "_mnff")
252273

253274
@abc.abstractmethod
254-
def __init__(self, *args, **kwargs): ...
275+
def __init__(self, *args, **kwargs):
276+
self._mnff = MemoryResource._MembersNeededForFinalize(self)
277+
278+
def close(self):
279+
"""Release any resources associated with this memory resource."""
280+
if hasattr(self, "_mnff"):
281+
self._mnff.close()
255282

256283
@abc.abstractmethod
257284
def allocate(self, size, stream=None) -> Buffer: ...
@@ -327,7 +354,22 @@ class Mempool(MemoryResource):
327354
from_shared_handle : Import an existing memory pool from another process
328355
"""
329356

330-
__slots__ = ("_dev_id", "_handle", "_ipc_enabled")
357+
class _MembersNeededForFinalize:
358+
__slots__ = ("handle", "is_initialized")
359+
360+
def __init__(self, mr_obj, handle=None):
361+
self.handle = handle
362+
weakref.finalize(mr_obj, self.close)
363+
364+
def close(self):
365+
if self.is_initialized and self.handle:
366+
handle_return(driver.cuMemPoolDestroy(self.handle))
367+
self.handle = None
368+
369+
__slots__ = (
370+
"__weakref__",
371+
"_mnff",
372+
)
331373

332374
def __init__(self):
333375
"""Direct instantiation is not supported.
@@ -360,8 +402,8 @@ def _init(dev_id: int, handle: int, ipc_enabled: bool) -> Mempool:
360402
"""
361403
self = Mempool.__new__(Mempool)
362404
self._dev_id = dev_id
363-
self._handle = handle
364405
self._ipc_enabled = ipc_enabled
406+
self._mnff = Mempool._MembersNeededForFinalize(self, handle)
365407
return self
366408

367409
@staticmethod
@@ -463,7 +505,7 @@ def get_shareable_handle(self) -> int:
463505
"""
464506
if not self._ipc_enabled:
465507
raise RuntimeError("This memory pool was not created with IPC support enabled")
466-
return handle_return(driver.cuMemPoolExportToShareableHandle(self._handle, _get_platform_handle_type(), 0))
508+
return handle_return(driver.cuMemPoolExportToShareableHandle(self._mnff.handle, _get_platform_handle_type(), 0))
467509

468510
def export_buffer(self, buffer: Buffer) -> IPCBuffer:
469511
"""Export a buffer allocated from this pool for sharing between processes.
@@ -528,7 +570,9 @@ def import_buffer(self, ipc_buffer: IPCBuffer) -> Buffer:
528570
raise RuntimeError("This memory pool was not created with IPC support enabled")
529571
share_data = driver.CUmemPoolPtrExportData()
530572
share_data.reserved = ipc_buffer.reserved
531-
return Buffer(handle_return(driver.cuMemPoolImportPointer(self._handle, share_data)), ipc_buffer._size, self)
573+
return Buffer(
574+
handle_return(driver.cuMemPoolImportPointer(self._mnff.handle, share_data)), ipc_buffer._size, self
575+
)
532576

533577
def allocate(self, size: int, stream=None) -> Buffer:
534578
"""Allocate memory from the pool.
@@ -552,7 +596,7 @@ def allocate(self, size: int, stream=None) -> Buffer:
552596
"""
553597
if stream is None:
554598
stream = default_stream()
555-
ptr = handle_return(driver.cuMemAllocFromPoolAsync(size, self._handle, stream.handle))
599+
ptr = handle_return(driver.cuMemAllocFromPoolAsync(size, self._mnff.handle, stream.handle))
556600
return Buffer(ptr, size, self)
557601

558602
def deallocate(self, ptr: int, size: int, stream=None) -> None:
@@ -615,7 +659,7 @@ def reuse_follow_event_dependencies(self) -> bool:
615659
return bool(
616660
handle_return(
617661
driver.cuMemPoolGetAttribute(
618-
self._handle, driver.CUmemPool_attribute.CU_MEMPOOL_ATTR_REUSE_FOLLOW_EVENT_DEPENDENCIES
662+
self._mnff.handle, driver.CUmemPool_attribute.CU_MEMPOOL_ATTR_REUSE_FOLLOW_EVENT_DEPENDENCIES
619663
)
620664
)
621665
)
@@ -626,7 +670,7 @@ def reuse_allow_opportunistic(self) -> bool:
626670
return bool(
627671
handle_return(
628672
driver.cuMemPoolGetAttribute(
629-
self._handle, driver.CUmemPool_attribute.CU_MEMPOOL_ATTR_REUSE_ALLOW_OPPORTUNISTIC
673+
self._mnff.handle, driver.CUmemPool_attribute.CU_MEMPOOL_ATTR_REUSE_ALLOW_OPPORTUNISTIC
630674
)
631675
)
632676
)
@@ -637,7 +681,7 @@ def reuse_allow_internal_dependencies(self) -> bool:
637681
return bool(
638682
handle_return(
639683
driver.cuMemPoolGetAttribute(
640-
self._handle, driver.CUmemPool_attribute.CU_MEMPOOL_ATTR_REUSE_ALLOW_INTERNAL_DEPENDENCIES
684+
self._mnff.handle, driver.CUmemPool_attribute.CU_MEMPOOL_ATTR_REUSE_ALLOW_INTERNAL_DEPENDENCIES
641685
)
642686
)
643687
)
@@ -647,7 +691,9 @@ def release_threshold(self) -> int:
647691
"""Amount of reserved memory to hold before OS release."""
648692
return int(
649693
handle_return(
650-
driver.cuMemPoolGetAttribute(self._handle, driver.CUmemPool_attribute.CU_MEMPOOL_ATTR_RELEASE_THRESHOLD)
694+
driver.cuMemPoolGetAttribute(
695+
self._mnff.handle, driver.CUmemPool_attribute.CU_MEMPOOL_ATTR_RELEASE_THRESHOLD
696+
)
651697
)
652698
)
653699

@@ -657,7 +703,7 @@ def reserved_mem_current(self) -> int:
657703
return int(
658704
handle_return(
659705
driver.cuMemPoolGetAttribute(
660-
self._handle, driver.CUmemPool_attribute.CU_MEMPOOL_ATTR_RESERVED_MEM_CURRENT
706+
self._mnff.handle, driver.CUmemPool_attribute.CU_MEMPOOL_ATTR_RESERVED_MEM_CURRENT
661707
)
662708
)
663709
)
@@ -667,7 +713,9 @@ def reserved_mem_high(self) -> int:
667713
"""High watermark of backing memory allocated."""
668714
return int(
669715
handle_return(
670-
driver.cuMemPoolGetAttribute(self._handle, driver.CUmemPool_attribute.CU_MEMPOOL_ATTR_RESERVED_MEM_HIGH)
716+
driver.cuMemPoolGetAttribute(
717+
self._mnff.handle, driver.CUmemPool_attribute.CU_MEMPOOL_ATTR_RESERVED_MEM_HIGH
718+
)
671719
)
672720
)
673721

@@ -676,7 +724,9 @@ def used_mem_current(self) -> int:
676724
"""Current amount of memory in use."""
677725
return int(
678726
handle_return(
679-
driver.cuMemPoolGetAttribute(self._handle, driver.CUmemPool_attribute.CU_MEMPOOL_ATTR_USED_MEM_CURRENT)
727+
driver.cuMemPoolGetAttribute(
728+
self._mnff.handle, driver.CUmemPool_attribute.CU_MEMPOOL_ATTR_USED_MEM_CURRENT
729+
)
680730
)
681731
)
682732

@@ -685,7 +735,9 @@ def used_mem_high(self) -> int:
685735
"""High watermark of memory in use."""
686736
return int(
687737
handle_return(
688-
driver.cuMemPoolGetAttribute(self._handle, driver.CUmemPool_attribute.CU_MEMPOOL_ATTR_USED_MEM_HIGH)
738+
driver.cuMemPoolGetAttribute(
739+
self._mnff.handle, driver.CUmemPool_attribute.CU_MEMPOOL_ATTR_USED_MEM_HIGH
740+
)
689741
)
690742
)
691743

0 commit comments

Comments
 (0)