    get_normalized_batch_axes,
    scalar_elemwise,
)
-from pytensor.tensor.shape import shape, specify_broadcastable
+from pytensor.tensor.shape import shape, specify_shape
from pytensor.tensor.type import (
    DenseTensorType,
    complex_dtypes,
    continuous_dtypes,
    discrete_dtypes,
+    float_dtypes,
    int_dtypes,
    tensor,
    uint_dtypes,
@@ -2986,9 +2987,7 @@ def clip(x, min, max):

class Dot(Op):
    """
-    Computes the dot product of two variables. For two matrices, this is
-    equivalent to matrix multiplication. For two vectors, this is the inner
-    product.
+    Computes the dot product of two matrix variables.

    Notes
    -----
@@ -3001,97 +3000,57 @@ class Dot(Op):

    """

+    gufunc_signature = "(m,n),(n,p)->(m,p)"
+    gufunc_spec = ("matmul", 2, 1)
    __props__ = ()

-    # the rationale for Dot22 is related to getting GEMM Ops into the
-    # graph. See Dot22 in tensor.blas for details.
-
-    def make_node(self, *inputs):
-        inputs = list(map(as_tensor_variable, inputs))
+    def make_node(self, x, y):
+        x = as_tensor_variable(x)
+        y = as_tensor_variable(y)

-        if len(inputs) != 2:
-            raise TypeError(f"Two arguments required, {len(inputs)} given")
-        if inputs[0].ndim not in (1, 2):
+        if x.type.ndim != 2:
            raise TypeError(
-                "Input 0 (0-indexed) must have ndim of "
-                f"1 or 2, {int(inputs[0].ndim)} given. Consider calling "
-                "pytensor.tensor.dot instead."
+                f"Dot Op expects a 2D tensor as input 0, got {x} with {x.type.ndim} dimensions"
            )
-        if inputs[1].ndim not in (1, 2):
+        if y.type.ndim != 2:
            raise TypeError(
-                "Input 1 (0-indexed) must have ndim of "
-                f"1 or 2, {int(inputs[1].ndim)} given. Consider calling "
-                "pytensor.tensor.dot instead."
+                f"Dot Op expects a 2D tensor as input 1, got {y} with {y.type.ndim} dimensions"
            )

-        sx, sy = (input.type.shape for input in inputs)
+        sx, sy = x.type.shape, y.type.shape
        if sx[-1] is not None and sy[0] is not None and sx[-1] != sy[0]:
            raise ValueError(
                f"Incompatible shared dimension for dot product: {sx}, {sy}"
            )
+        sz = sx[:-1] + sy[-1:]
+        outputs = [tensor(dtype=ps.upcast(x.type.dtype, y.type.dtype), shape=sz)]
+        return Apply(self, [x, y], outputs)

-        if len(sy) == 2:
-            sz = sx[:-1] + sy[-1:]
-        elif len(sy) == 1:
-            sz = sx[:-1]
-
-        i_dtypes = [input.type.dtype for input in inputs]
-        outputs = [tensor(dtype=ps.upcast(*i_dtypes), shape=sz)]
-        return Apply(self, inputs, outputs)
-
-    def perform(self, node, inp, out):
-        x, y = inp
-        (z,) = out
-
-        # the asarray is here because dot between two vectors
-        # gives a numpy float object but we need to return a 0d
-        # ndarray
-        z[0] = np.asarray(np.dot(x, y))
+    def perform(self, node, inputs, output_storage):
+        output_storage[0][0] = np.matmul(*inputs)

    def grad(self, inp, grads):
        x, y = inp
        (gz,) = grads
-        xdim, ydim, gdim = x.type.ndim, y.type.ndim, gz.type.ndim
-
-        # grad is scalar, so x is vector and y is vector
-        if gdim == 0:
-            xgrad = gz * y
-            ygrad = gz * x
-
-        # x is vector, y is matrix, grad is vector
-        elif xdim == 1 and ydim == 2:
-            xgrad = dot(gz, y.T)
-            ygrad = outer(x.T, gz)

-        # x is matrix, y is vector, grad is vector
-        elif xdim == 2 and ydim == 1:
-            xgrad = outer(gz, y.T)
-            ygrad = dot(x.T, gz)
-
-        # x is matrix, y is matrix, grad is matrix
-        elif xdim == ydim == 2:
-            xgrad = dot(gz, y.T)
-            ygrad = dot(x.T, gz)
+        xgrad = self(gz, y.T)
+        ygrad = self(x.T, gz)

        # If x or y contain broadcastable dimensions but only one of
        # them know that a matching dimensions is broadcastable, the
        # above code don't always return the right broadcast pattern.
        # This cause problem down the road. See gh-1461.
-        if xgrad.broadcastable != x.broadcastable:
-            xgrad = specify_broadcastable(
-                xgrad, *(ax for (ax, b) in enumerate(x.type.broadcastable) if b)
-            )
-        if ygrad.broadcastable != y.broadcastable:
-            ygrad = specify_broadcastable(
-                ygrad, *(ax for (ax, b) in enumerate(y.type.broadcastable) if b)
-            )
+        if xgrad.type.shape != x.type.shape:
+            xgrad = specify_shape(xgrad, x.type.shape)
+        if ygrad.type.shape != y.type.shape:
+            ygrad = specify_shape(ygrad, y.type.shape)

-        rval = xgrad, ygrad
+        if xgrad.type.dtype not in float_dtypes:
+            raise TypeError("Dot grad x output must be a float type")
+        if ygrad.type.dtype not in float_dtypes:
+            raise TypeError("Dot grad y output must be a float type")

-        for elem in rval:
-            assert elem.dtype.find("float") != -1
-
-        return rval
+        return xgrad, ygrad

    def R_op(self, inputs, eval_points):
        # R_op for a \dot b evaluated at c for a and d for b is
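
For context on the grad rewrite above: with both inputs now guaranteed to be matrices, only the matrix-matrix branch of the old code survives, i.e. the pullback of an output gradient gz through x @ y is gz @ y.T for x and x.T @ gz for y. A quick NumPy finite-difference check of that identity (illustrative only, not part of the diff; shapes and seed chosen arbitrarily):

import numpy as np

rng = np.random.default_rng(0)
x, y = rng.normal(size=(3, 4)), rng.normal(size=(4, 5))
gz = np.ones((3, 5))  # output gradient of sum(x @ y)

# Forward finite differences of d sum(x @ y) / dx, compared to gz @ y.T
eps = 1e-6
fd = np.empty_like(x)
for i in range(3):
    for j in range(4):
        dx = np.zeros_like(x)
        dx[i, j] = eps
        fd[i, j] = (np.sum((x + dx) @ y) - np.sum(x @ y)) / eps
assert np.allclose(fd, gz @ y.T, atol=1e-4)
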
@@ -3116,24 +3075,7 @@ def R_op(self, inputs, eval_points):

    def infer_shape(self, fgraph, node, shapes):
        xshp, yshp = shapes
-        x, y = node.inputs
-
-        # vector / vector
-        if x.ndim == 1 and y.ndim == 1:
-            return [()]
-        # matrix / vector
-        if x.ndim == 2 and y.ndim == 1:
-            return [xshp[:-1]]
-        # vector / matrix
-        if x.ndim == 1 and y.ndim == 2:
-            return [yshp[-1:]]
-        # matrix / matrix
-        if x.ndim == 2 and y.ndim == 2:
-            return [xshp[:-1] + yshp[-1:]]
-        raise NotImplementedError()
-
-    def __str__(self):
-        return "dot"
+        return [[xshp[0], yshp[1]]]


_dot = Dot()
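
With the op restricted to 2D inputs, infer_shape collapses to rows-of-x by columns-of-y, matching the static shape computed in make_node. A small sketch of the resulting shape inference (assumes the usual `import pytensor.tensor as pt` alias; not part of the diff):

import pytensor.tensor as pt

x = pt.tensor("x", shape=(3, None))
y = pt.tensor("y", shape=(None, 5))
z = x @ y  # routed through the 2D-only Dot / Matmul machinery
print(z.type.shape)  # expected: (3, 5)
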
@@ -3215,7 +3157,24 @@ def dense_dot(a, b):
    elif a.ndim > 2 or b.ndim > 2:
        return tensordot(a, b, [[a.ndim - 1], [np.maximum(0, b.ndim - 2)]])
    else:
-        return _dot(a, b)
+        row_vector = a.ndim == 1
+        if row_vector:
+            # Promote to row matrix
+            a = a[None]
+
+        col_vector = b.ndim == 1
+        if col_vector:
+            # Promote to column matrix
+            b = b[:, None]
+
+        out = _dot(a, b)
+        if row_vector:
+            # If we promoted a to a row matrix, we need to squeeze the first dimension
+            out = out.squeeze(0)
+        if col_vector:
+            # If we promoted b to a column matrix, we need to squeeze the last dimension
+            out = out.squeeze(-1)
+        return out


def tensordot(
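
The vector cases that the stricter Dot no longer accepts are handled in dense_dot above by promoting to matrices and squeezing the result, so `pt.dot` keeps its NumPy-like semantics. A brief illustration of the intended dimensionalities (assumes `import pytensor.tensor as pt`; not part of the diff):

import pytensor.tensor as pt

m = pt.matrix("m")  # (r, c)
v = pt.vector("v")  # (c,)

mv = pt.dot(m, v)  # v promoted to (c, 1), result squeezed back to (r,)
vv = pt.dot(v, v)  # both promoted, result squeezed to a 0-d scalar

print(mv.ndim, vv.ndim)  # expected: 1 0
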
@@ -3921,11 +3880,7 @@ def logsumexp(x, axis=None, keepdims=False):
    return log(sum(exp(x), axis=axis, keepdims=keepdims))


-_matmul = Blockwise(
-    _dot,
-    signature="(m,k),(k,n)->(m,n)",
-    gufunc_spec=("numpy.matmul", 2, 1),
-)
+_matmul = Blockwise(_dot, name="Matmul")


def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
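
Since Dot now declares gufunc_signature and gufunc_spec itself, the Blockwise wrapper no longer needs them spelled out, and batched behaviour should be unchanged. A rough usage sketch of matmul broadcasting over leading batch dimensions (illustrative only; shapes are made up):

import pytensor.tensor as pt

a = pt.tensor("a", shape=(7, 3, 4))  # batch of 7 matrices
b = pt.tensor("b", shape=(4, 5))     # single matrix, broadcast across the batch

c = pt.matmul(a, b)
print(c.type.shape)  # expected: (7, 3, 5)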