@@ -213,7 +213,8 @@ def __release_buffer__(self, buffer: memoryview, /):
213213class IPCBuffer (Buffer ):
214214 """Buffer class to represent a buffer description which can be shared across processes.
215215 It is not a valid buffer containing data, but rather a description used by the importing
216- process to construct a valid buffer."""
216+ process to construct a valid buffer. It's primary use is to provide a serialization
217+ mechanism for passing exported buffers between processes."""
217218
218219 def __init__ (self , reserved : bytes , size ):
219220 super ().__init__ (0 , 0 )
@@ -248,10 +249,36 @@ def _reconstruct(cls, reserved, size):
248249
249250
250251class MemoryResource (abc .ABC ):
251- __slots__ = ("_handle" ,)
252+ """Base class for memory resources.
253+
254+ This class provides an abstract interface for memory resources and includes
255+ an optional destruction mechanism through _MembersNeededForFinalize.
256+ """
257+
258+ class _MembersNeededForFinalize :
259+ __slots__ = ("handle" , "is_initialized" )
260+
261+ def __init__ (self , mr_obj , handle = None ):
262+ self .handle = handle
263+ self .is_initialized = False
264+ weakref .finalize (mr_obj , self .close )
265+
266+ def close (self ):
267+ if self .is_initialized and self .handle :
268+ # Specific cleanup can be implemented by derived classes
269+ self .handle = None
270+ self .is_initialized = False
271+
272+ __slots__ = ("__weakref__" , "_mnff" )
252273
253274 @abc .abstractmethod
254- def __init__ (self , * args , ** kwargs ): ...
275+ def __init__ (self , * args , ** kwargs ):
276+ self ._mnff = MemoryResource ._MembersNeededForFinalize (self )
277+
278+ def close (self ):
279+ """Release any resources associated with this memory resource."""
280+ if hasattr (self , "_mnff" ):
281+ self ._mnff .close ()
255282
256283 @abc .abstractmethod
257284 def allocate (self , size , stream = None ) -> Buffer : ...
@@ -327,7 +354,22 @@ class Mempool(MemoryResource):
327354 from_shared_handle : Import an existing memory pool from another process
328355 """
329356
330- __slots__ = ("_dev_id" , "_handle" , "_ipc_enabled" )
357+ class _MembersNeededForFinalize :
358+ __slots__ = ("handle" , "is_initialized" )
359+
360+ def __init__ (self , mr_obj , handle = None ):
361+ self .handle = handle
362+ weakref .finalize (mr_obj , self .close )
363+
364+ def close (self ):
365+ if self .is_initialized and self .handle :
366+ handle_return (driver .cuMemPoolDestroy (self .handle ))
367+ self .handle = None
368+
369+ __slots__ = (
370+ "__weakref__" ,
371+ "_mnff" ,
372+ )
331373
332374 def __init__ (self ):
333375 """Direct instantiation is not supported.
@@ -360,8 +402,8 @@ def _init(dev_id: int, handle: int, ipc_enabled: bool) -> Mempool:
360402 """
361403 self = Mempool .__new__ (Mempool )
362404 self ._dev_id = dev_id
363- self ._handle = handle
364405 self ._ipc_enabled = ipc_enabled
406+ self ._mnff = Mempool ._MembersNeededForFinalize (self , handle )
365407 return self
366408
367409 @staticmethod
@@ -463,7 +505,7 @@ def get_shareable_handle(self) -> int:
463505 """
464506 if not self ._ipc_enabled :
465507 raise RuntimeError ("This memory pool was not created with IPC support enabled" )
466- return handle_return (driver .cuMemPoolExportToShareableHandle (self ._handle , _get_platform_handle_type (), 0 ))
508+ return handle_return (driver .cuMemPoolExportToShareableHandle (self ._mnff . handle , _get_platform_handle_type (), 0 ))
467509
468510 def export_buffer (self , buffer : Buffer ) -> IPCBuffer :
469511 """Export a buffer allocated from this pool for sharing between processes.
@@ -528,7 +570,9 @@ def import_buffer(self, ipc_buffer: IPCBuffer) -> Buffer:
528570 raise RuntimeError ("This memory pool was not created with IPC support enabled" )
529571 share_data = driver .CUmemPoolPtrExportData ()
530572 share_data .reserved = ipc_buffer .reserved
531- return Buffer (handle_return (driver .cuMemPoolImportPointer (self ._handle , share_data )), ipc_buffer ._size , self )
573+ return Buffer (
574+ handle_return (driver .cuMemPoolImportPointer (self ._mnff .handle , share_data )), ipc_buffer ._size , self
575+ )
532576
533577 def allocate (self , size : int , stream = None ) -> Buffer :
534578 """Allocate memory from the pool.
@@ -552,7 +596,7 @@ def allocate(self, size: int, stream=None) -> Buffer:
552596 """
553597 if stream is None :
554598 stream = default_stream ()
555- ptr = handle_return (driver .cuMemAllocFromPoolAsync (size , self ._handle , stream .handle ))
599+ ptr = handle_return (driver .cuMemAllocFromPoolAsync (size , self ._mnff . handle , stream .handle ))
556600 return Buffer (ptr , size , self )
557601
558602 def deallocate (self , ptr : int , size : int , stream = None ) -> None :
@@ -615,7 +659,7 @@ def reuse_follow_event_dependencies(self) -> bool:
615659 return bool (
616660 handle_return (
617661 driver .cuMemPoolGetAttribute (
618- self ._handle , driver .CUmemPool_attribute .CU_MEMPOOL_ATTR_REUSE_FOLLOW_EVENT_DEPENDENCIES
662+ self ._mnff . handle , driver .CUmemPool_attribute .CU_MEMPOOL_ATTR_REUSE_FOLLOW_EVENT_DEPENDENCIES
619663 )
620664 )
621665 )
@@ -626,7 +670,7 @@ def reuse_allow_opportunistic(self) -> bool:
626670 return bool (
627671 handle_return (
628672 driver .cuMemPoolGetAttribute (
629- self ._handle , driver .CUmemPool_attribute .CU_MEMPOOL_ATTR_REUSE_ALLOW_OPPORTUNISTIC
673+ self ._mnff . handle , driver .CUmemPool_attribute .CU_MEMPOOL_ATTR_REUSE_ALLOW_OPPORTUNISTIC
630674 )
631675 )
632676 )
@@ -637,7 +681,7 @@ def reuse_allow_internal_dependencies(self) -> bool:
637681 return bool (
638682 handle_return (
639683 driver .cuMemPoolGetAttribute (
640- self ._handle , driver .CUmemPool_attribute .CU_MEMPOOL_ATTR_REUSE_ALLOW_INTERNAL_DEPENDENCIES
684+ self ._mnff . handle , driver .CUmemPool_attribute .CU_MEMPOOL_ATTR_REUSE_ALLOW_INTERNAL_DEPENDENCIES
641685 )
642686 )
643687 )
@@ -647,7 +691,9 @@ def release_threshold(self) -> int:
647691 """Amount of reserved memory to hold before OS release."""
648692 return int (
649693 handle_return (
650- driver .cuMemPoolGetAttribute (self ._handle , driver .CUmemPool_attribute .CU_MEMPOOL_ATTR_RELEASE_THRESHOLD )
694+ driver .cuMemPoolGetAttribute (
695+ self ._mnff .handle , driver .CUmemPool_attribute .CU_MEMPOOL_ATTR_RELEASE_THRESHOLD
696+ )
651697 )
652698 )
653699
@@ -657,7 +703,7 @@ def reserved_mem_current(self) -> int:
657703 return int (
658704 handle_return (
659705 driver .cuMemPoolGetAttribute (
660- self ._handle , driver .CUmemPool_attribute .CU_MEMPOOL_ATTR_RESERVED_MEM_CURRENT
706+ self ._mnff . handle , driver .CUmemPool_attribute .CU_MEMPOOL_ATTR_RESERVED_MEM_CURRENT
661707 )
662708 )
663709 )
@@ -667,7 +713,9 @@ def reserved_mem_high(self) -> int:
667713 """High watermark of backing memory allocated."""
668714 return int (
669715 handle_return (
670- driver .cuMemPoolGetAttribute (self ._handle , driver .CUmemPool_attribute .CU_MEMPOOL_ATTR_RESERVED_MEM_HIGH )
716+ driver .cuMemPoolGetAttribute (
717+ self ._mnff .handle , driver .CUmemPool_attribute .CU_MEMPOOL_ATTR_RESERVED_MEM_HIGH
718+ )
671719 )
672720 )
673721
@@ -676,7 +724,9 @@ def used_mem_current(self) -> int:
676724 """Current amount of memory in use."""
677725 return int (
678726 handle_return (
679- driver .cuMemPoolGetAttribute (self ._handle , driver .CUmemPool_attribute .CU_MEMPOOL_ATTR_USED_MEM_CURRENT )
727+ driver .cuMemPoolGetAttribute (
728+ self ._mnff .handle , driver .CUmemPool_attribute .CU_MEMPOOL_ATTR_USED_MEM_CURRENT
729+ )
680730 )
681731 )
682732
@@ -685,7 +735,9 @@ def used_mem_high(self) -> int:
685735 """High watermark of memory in use."""
686736 return int (
687737 handle_return (
688- driver .cuMemPoolGetAttribute (self ._handle , driver .CUmemPool_attribute .CU_MEMPOOL_ATTR_USED_MEM_HIGH )
738+ driver .cuMemPoolGetAttribute (
739+ self ._mnff .handle , driver .CUmemPool_attribute .CU_MEMPOOL_ATTR_USED_MEM_HIGH
740+ )
689741 )
690742 )
691743
0 commit comments