@@ -120,6 +120,33 @@ def is_torch_array(x):
120120    # TODO: Should we reject ndarray subclasses? 
121121    return  isinstance (x , torch .Tensor )
122122
def is_paddle_array(x):
    """
    Return True if `x` is a Paddle tensor.

    This function does not import Paddle if it has not already been imported
    and is therefore cheap to use.

    See Also
    --------

    array_namespace
    is_array_api_obj
    is_numpy_array
    is_cupy_array
    is_torch_array
    is_ndonnx_array
    is_dask_array
    is_jax_array
    is_pydata_sparse_array
    """
    # Avoid importing paddle if it isn't already
    if 'paddle' not in sys.modules:
        return False

    import paddle

    # TODO: Should we reject ndarray subclasses?
    return paddle.is_tensor(x)
149+ 
123150def  is_ndonnx_array (x ):
124151    """ 
125152    Return True if `x` is a ndonnx Array. 
@@ -252,6 +279,7 @@ def is_array_api_obj(x):
252279        or  is_dask_array (x ) \
253280        or  is_jax_array (x ) \
254281        or  is_pydata_sparse_array (x ) \
282+         or  is_paddle_array (x ) \
255283        or  hasattr (x , '__array_namespace__' )
256284
257285def  _compat_module_name ():
@@ -319,6 +347,27 @@ def is_torch_namespace(xp) -> bool:
319347    return  xp .__name__  in  {'torch' , _compat_module_name () +  '.torch' }
320348
321349
def is_paddle_namespace(xp) -> bool:
    """
    Returns True if `xp` is a Paddle namespace.

    This includes both Paddle itself and the version wrapped by array-api-compat.

    See Also
    --------

    array_namespace
    is_numpy_namespace
    is_cupy_namespace
    is_torch_namespace
    is_ndonnx_namespace
    is_dask_namespace
    is_jax_namespace
    is_pydata_sparse_namespace
    is_array_api_strict_namespace
    """
    # Match either the plain module or the compat-wrapped one
    # (e.g. 'array_api_compat.paddle').
    return xp.__name__ in {'paddle', _compat_module_name() + '.paddle'}
369+ 
370+ 
322371def  is_ndonnx_namespace (xp ):
323372    """ 
324373    Returns True if `xp` is an NDONNX namespace. 
@@ -543,6 +592,14 @@ def your_function(x, y):
543592                else :
544593                    import  jax .experimental .array_api  as  jnp 
545594            namespaces .add (jnp )
595+         elif  is_paddle_array (x ):
596+             if  _use_compat :
597+                 _check_api_version (api_version )
598+                 from  .. import  paddle  as  paddle_namespace 
599+                 namespaces .add (paddle_namespace )
600+             else :
601+                 import  paddle 
602+                 namespaces .add (paddle )
546603        elif  is_pydata_sparse_array (x ):
547604            if  use_compat  is  True :
548605                _check_api_version (api_version )
@@ -660,6 +717,16 @@ def device(x: Array, /) -> Device:
660717            return  "cpu" 
661718        # Return the device of the constituent array 
662719        return  device (inner )
720+     elif  is_paddle_array (x ):
721+         raw_place_str  =  str (x .place )
722+         if  "gpu_pinned"  in  raw_place_str :
723+             return  "cpu" 
724+         elif  "cpu"  in  raw_place_str :
725+             return  "cpu" 
726+         elif  "gpu"  in  raw_place_str :
727+             return  "gpu" 
728+         raise  NotImplementedError (f"Unsupported device { raw_place_str }  )
729+ 
663730    return  x .device 
664731
665732# Prevent shadowing, used below 
@@ -709,6 +776,14 @@ def _torch_to_device(x, device, /, stream=None):
709776        raise  NotImplementedError 
710777    return  x .to (device )
711778
779+ def  _paddle_to_device (x , device , / , stream = None ):
780+     if  stream  is  not None :
781+         raise  NotImplementedError (
782+             "paddle.Tensor.to() do not support stream argument yet" 
783+         )
784+     return  x .to (device )
785+ 
786+ 
712787def  to_device (x : Array , device : Device , / , * , stream : Optional [Union [int , Any ]] =  None ) ->  Array :
713788    """ 
714789    Copy the array from the device on which it currently resides to the specified ``device``. 
@@ -781,6 +856,8 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
781856            # In JAX v0.4.31 and older, this import adds to_device method to x. 
782857            import  jax .experimental .array_api  # noqa: F401 
783858        return  x .to_device (device , stream = stream )
859+     elif  is_paddle_array (x ):
860+         return  _paddle_to_device (x , device , stream = stream )
784861    elif  is_pydata_sparse_array (x ) and  device  ==  _device (x ):
785862        # Perform trivial check to return the same array if 
786863        # device is same instead of err-ing. 
@@ -819,6 +896,8 @@ def size(x):
819896    "is_torch_namespace" ,
820897    "is_ndonnx_array" ,
821898    "is_ndonnx_namespace" ,
899+     "is_paddle_array" ,
900+     "is_paddle_namespace" ,
822901    "is_pydata_sparse_array" ,
823902    "is_pydata_sparse_namespace" ,
824903    "size" ,
0 commit comments