
from typing import TYPE_CHECKING
if TYPE_CHECKING:
-    import paddle
-    array = paddle.Tensor
-    from paddle import dtype as Dtype
+    import torch
+    array = torch.Tensor
+    from torch import dtype as Dtype
    from typing import Optional, Union, Tuple, Literal
    inf = float('inf')

from ._aliases import _fix_promotion, sum

-from paddle.linalg import *  # noqa: F403
+from torch.linalg import *  # noqa: F403

-# paddle.linalg doesn't define __all__
-# from paddle.linalg import __all__ as linalg_all
-from paddle import linalg as paddle_linalg
-linalg_all = [i for i in dir(paddle_linalg) if not i.startswith('_')]
+# torch.linalg doesn't define __all__
+# from torch.linalg import __all__ as linalg_all
+from torch import linalg as torch_linalg
+linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')]

-# outer is implemented in paddle but isn't in the linalg namespace
-from paddle import outer
+# outer is implemented in torch but isn't in the linalg namespace
+from torch import outer
# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot

-# Note: paddle.linalg.cross does not default to axis=-1 (it defaults to the
+# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
+# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743

-# paddle.cross also does not support broadcasting when it would add new
+# torch.cross also does not support broadcasting when it would add new
+# dimensions https://github.com/pytorch/pytorch/issues/39656
def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
    if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
        raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}")
    if not (x1.shape[axis] == x2.shape[axis] == 3):
        raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}")
-    x1, x2 = paddle.broadcast_tensors(x1, x2)
-    return paddle_linalg.cross(x1, x2, axis=axis)
+    x1, x2 = torch.broadcast_tensors(x1, x2)
+    return torch_linalg.cross(x1, x2, dim=axis)
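
A minimal usage sketch (not part of the diff) of calling the wrapper with inputs that broadcast to a new dimension; the import path array_api_compat.torch.linalg is an assumption here, and torch must be installed:

import torch
from array_api_compat.torch import linalg as xp_linalg  # assumed import path

a = torch.randn(4, 3)
b = torch.randn(3)            # broadcasting adds a new leading dimension
out = xp_linalg.cross(a, b)   # computed along axis=-1, per the array API spec
print(out.shape)              # torch.Size([4, 3])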

def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
    from ._aliases import isdtype

    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)

-    # paddle.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
+    # torch.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
    if x1.shape[axis] != x2.shape[axis]:
        raise ValueError("x1 and x2 must have the same size along the given axis")

-    # paddle.linalg.vecdot doesn't support integer dtypes
+    # torch.linalg.vecdot doesn't support integer dtypes
    if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
        if kwargs:
            raise RuntimeError("vecdot kwargs not supported for integral dtypes")

-        x1_ = paddle.moveaxis(x1, axis, -1)
-        x2_ = paddle.moveaxis(x2, axis, -1)
-        x1_, x2_ = paddle.broadcast_tensors(x1_, x2_)
+        x1_ = torch.moveaxis(x1, axis, -1)
+        x2_ = torch.moveaxis(x2, axis, -1)
+        x1_, x2_ = torch.broadcast_tensors(x1_, x2_)

        res = x1_[..., None, :] @ x2_[..., None]
        return res[..., 0, 0]
-    return paddle.linalg.vecdot(x1, x2, axis=axis, **kwargs)
+    return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)
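
A sketch (outside the diff, same assumed import path) of the integer-dtype fallback, which reduces vecdot to a batched matmul:

import torch
from array_api_compat.torch import linalg as xp_linalg  # assumed import path

x = torch.arange(6).reshape(2, 3)              # int64, so the matmul fallback is used
y = torch.arange(6).reshape(2, 3)
print(xp_linalg.vecdot(x, y, axis=-1))         # tensor([ 5, 50])
print(xp_linalg.vecdot(x.float(), y.float()))  # same values via torch.linalg.vecdot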

def solve(x1: array, x2: array, /, **kwargs) -> array:
    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
-    # paddle tries to emulate NumPy 1 solve behavior by using batched 1-D solve
+    # Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve
    # whenever
    # 1. x1.ndim - 1 == x2.ndim
    # 2. x1.shape[:-1] == x2.shape
    #
    # See linalg_solve_is_vector_rhs in
    # aten/src/ATen/native/LinearAlgebraUtils.h and
-    # paddle_META_FUNC (_linalg_solve_ex) in
-    # aten/src/ATen/native/BatchLinearAlgebra.cpp in the Pypaddle source code.
+    # TORCH_META_FUNC (_linalg_solve_ex) in
+    # aten/src/ATen/native/BatchLinearAlgebra.cpp in the PyTorch source code.
    #
    # The easiest way to work around this is to prepend a size 1 dimension to
    # x2, since x2 is already one dimension less than x1.
    #
-    # See https://github.com/pypaddle/pypaddle/issues/52915
+    # See https://github.com/pytorch/pytorch/issues/52915
    if x2.ndim != 1 and x1.ndim - 1 == x2.ndim and x1.shape[:-1] == x2.shape:
        x2 = x2[None]
-    return paddle.linalg.solve(x1, x2, **kwargs)
+    return torch.linalg.solve(x1, x2, **kwargs)
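
A sketch (outside the diff, assumed import path as above) of the ambiguous shape the size-1 prepend guards against:

import torch
from array_api_compat.torch import linalg as xp_linalg  # assumed import path

A = torch.randn(2, 2, 2)               # stack of two 2x2 systems
B = torch.randn(2, 2)                  # x1.ndim - 1 == x2.ndim and x1.shape[:-1] == x2.shape
print(torch.linalg.solve(A, B).shape)  # torch.Size([2, 2])    -- batched 1-D solve (NumPy 1 rule)
print(xp_linalg.solve(A, B).shape)     # torch.Size([2, 2, 2]) -- B treated as a matrix, per the spec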

-# paddle.trace doesn't support the offset argument and doesn't support stacking
+# torch.trace doesn't support the offset argument and doesn't support stacking
def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
    # Use our wrapped sum to make sure it does upcasting correctly
-    return sum(paddle.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
+    return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
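
A sketch (outside the diff, assumed import path as above) of the offset and stacking support this wrapper adds over torch.trace:

import torch
from array_api_compat.torch import linalg as xp_linalg  # assumed import path

x = torch.arange(18, dtype=torch.float32).reshape(2, 3, 3)
print(xp_linalg.trace(x))            # tensor([12., 39.]) -- trace of each 3x3 matrix
print(xp_linalg.trace(x, offset=1))  # tensor([ 6., 24.]) -- sums of the first superdiagonals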

def vector_norm(
    x: array,
@@ -90,30 +92,30 @@ def vector_norm(
    ord: Union[int, float, Literal[inf, -inf]] = 2,
    **kwargs,
) -> array:
-    # paddle.vector_norm incorrectly treats axis=() the same as axis=None
+    # torch.vector_norm incorrectly treats axis=() the same as axis=None
    if axis == ():
        out = kwargs.get('out')
        if out is None:
            dtype = None
-            if x.dtype == paddle.complex64:
-                dtype = paddle.float32
-            elif x.dtype == paddle.complex128:
-                dtype = paddle.float64
+            if x.dtype == torch.complex64:
+                dtype = torch.float32
+            elif x.dtype == torch.complex128:
+                dtype = torch.float64

-            out = paddle.zeros_like(x, dtype=dtype)
+            out = torch.zeros_like(x, dtype=dtype)

        # The norm of a single scalar works out to abs(x) in every case except
-        # for p=0, which is x != 0.
+        # for ord=0, which is x != 0.
        if ord == 0:
            out[:] = (x != 0)
        else:
-            out[:] = paddle.abs(x)
+            out[:] = torch.abs(x)
        return out
-    return paddle.linalg.vector_norm(x, p=ord, axis=axis, keepdim=keepdims, **kwargs)
+    return torch.linalg.vector_norm(x, ord=ord, dim=axis, keepdim=keepdims, **kwargs)
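
A sketch (outside the diff, assumed import path as above) of the axis=() special case handled above:

import torch
from array_api_compat.torch import linalg as xp_linalg  # assumed import path

x = torch.tensor([[3.0, -4.0], [0.0, 12.0]])
print(xp_linalg.vector_norm(x))                  # tensor(13.) -- 2-norm over all elements
print(xp_linalg.vector_norm(x, axis=()))         # elementwise |x|, same shape as x
print(xp_linalg.vector_norm(x, axis=(), ord=0))  # elementwise (x != 0), stored as floats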

__all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',
                        'cross', 'vecdot', 'solve', 'trace', 'vector_norm']

-_all_ignore = ['paddle_linalg', 'sum']
+_all_ignore = ['torch_linalg', 'sum']

del linalg_all