
Commit 115c865

Make Dot only accept matrix inputs
1 parent 5c01ab6 commit 115c865

17 files changed: +128 -417 lines changed


pytensor/tensor/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -107,7 +107,6 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:
 from pytensor.tensor import (
     blas,
     blas_c,
-    blas_scipy,
     sharedvar,
     xlogx,
 )

pytensor/tensor/basic.py

Lines changed: 1 addition & 2 deletions
@@ -1801,8 +1801,7 @@ def do_constant_folding(self, fgraph, node):
                 | pytensor.tensor.blas.Gemv
                 | pytensor.tensor.blas_c.CGemv
                 | pytensor.tensor.blas.Ger
-                | pytensor.tensor.blas_c.CGer
-                | pytensor.tensor.blas_scipy.ScipyGer,
+                | pytensor.tensor.blas_c.CGer,
             )
         ):
             # Ops that will work inplace on the Alloc. So if they

pytensor/tensor/blas.py

Lines changed: 14 additions & 22 deletions
@@ -83,6 +83,7 @@
 from pathlib import Path

 import numpy as np
+from scipy.linalg import get_blas_funcs

 from pytensor.graph import vectorize_graph
 from pytensor.npy_2_compat import normalize_axis_tuple
@@ -288,18 +289,17 @@ def make_node(self, A, alpha, x, y):

         return Apply(self, inputs, [A.type()])

-    def perform(self, node, inp, out):
-        cA, calpha, cx, cy = inp
-        (cZ,) = out
-        if self.destructive:
-            A = cA
-        else:
-            A = cA.copy()
-        if calpha != 1:
-            A += calpha * np.outer(cx, cy)
-        else:
-            A += np.outer(cx, cy)
-        cZ[0] = A
+    def perform(self, node, inputs, output_storage):
+        A, alpha, x, y = inputs
+        if A.size:
+            # GER doesn't handle zero-sized inputs
+            ger_func = get_blas_funcs("ger", dtype=A.dtype)
+            if A.flags["C_CONTIGUOUS"]:
+                # Work on transposed system to avoid copying
+                A = ger_func(alpha, y, x, a=A.T, overwrite_a=self.destructive).T
+            else:
+                A = ger_func(alpha, x, y, a=A, overwrite_a=self.destructive)
+        output_storage[0][0] = A

     def infer_shape(self, fgraph, node, input_shapes):
         return [input_shapes[0]]
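
The new Ger.perform delegates the rank-1 update to SciPy's BLAS wrapper. Below is a minimal NumPy/SciPy sketch (not part of the commit; array names are illustrative) of the transpose trick used above: BLAS GER expects a Fortran-ordered matrix, so for a C-contiguous A the update is applied to A.T with x and y swapped, using (A + alpha * outer(x, y)).T == A.T + alpha * outer(y, x).

import numpy as np
from scipy.linalg import get_blas_funcs

A = np.zeros((3, 2))                  # C-contiguous, so A.T is Fortran-ordered
x = np.array([1.0, 2.0, 3.0])
y = np.array([4.0, 5.0])
alpha = 0.5

ger = get_blas_funcs("ger", dtype=A.dtype)
# Updating the transposed system lets BLAS work in place without copying A.
out = ger(alpha, y, x, a=A.T, overwrite_a=True).T

np.testing.assert_allclose(out, alpha * np.outer(x, y))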
@@ -1128,16 +1128,8 @@ def make_node(self, x, y):
         outputs = [tensor(dtype=x.type.dtype, shape=(x.type.shape[0], y.type.shape[1]))]
         return Apply(self, [x, y], outputs)

-    def perform(self, node, inp, out):
-        x, y = inp
-        (z,) = out
-        try:
-            z[0] = np.asarray(np.dot(x, y))
-        except ValueError as e:
-            # The error raised by numpy has no shape information, we mean to
-            # add that
-            e.args = (*e.args, x.shape, y.shape)
-            raise
+    def perform(self, node, inputs, output_storage):
+        output_storage[0][0] = np.dot(*inputs)

     def infer_shape(self, fgraph, node, input_shapes):
         return [[input_shapes[0][0], input_shapes[1][1]]]

pytensor/tensor/blas_scipy.py

Lines changed: 0 additions & 34 deletions
This file was deleted.

pytensor/tensor/math.py

Lines changed: 49 additions & 94 deletions
@@ -40,12 +40,13 @@
     get_normalized_batch_axes,
     scalar_elemwise,
 )
-from pytensor.tensor.shape import shape, specify_broadcastable
+from pytensor.tensor.shape import shape, specify_shape
 from pytensor.tensor.type import (
     DenseTensorType,
     complex_dtypes,
     continuous_dtypes,
     discrete_dtypes,
+    float_dtypes,
     int_dtypes,
     tensor,
     uint_dtypes,
@@ -2986,9 +2987,7 @@ def clip(x, min, max):

 class Dot(Op):
     """
-    Computes the dot product of two variables. For two matrices, this is
-    equivalent to matrix multiplication. For two vectors, this is the inner
-    product.
+    Computes the dot product of two matrices variables

     Notes
     -----
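
A hedged usage sketch (relying on the public pytensor.tensor API, not code from this commit): the core Dot Op is now matrix-only, but the user-facing dot still accepts vectors because dense_dot promotes them to matrices and squeezes the result (see the dense_dot hunk below).

import pytensor.tensor as pt

m = pt.matrix("m")     # 2-D input: handled by Dot directly
v = pt.vector("v")     # 1-D input: promoted and squeezed by dense_dot

z1 = pt.dot(m, m.T)    # matrix @ matrix -> 2-D result
z2 = pt.dot(m, v)      # matrix @ vector -> 1-D result (via promotion)
assert z1.ndim == 2 and z2.ndim == 1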
@@ -3001,97 +3000,57 @@ class Dot(Op):

     """

+    gufunc_signature = "(m,n),(n,p)->(m,p)"
+    gufunc_spec = ("matmul", 2, 1)
     __props__ = ()

-    # the rationale for Dot22 is related to getting GEMM Ops into the
-    # graph. See Dot22 in tensor.blas for details.
-
-    def make_node(self, *inputs):
-        inputs = list(map(as_tensor_variable, inputs))
+    def make_node(self, x, y):
+        x = as_tensor_variable(x)
+        y = as_tensor_variable(y)

-        if len(inputs) != 2:
-            raise TypeError(f"Two arguments required, {len(inputs)} given ")
-        if inputs[0].ndim not in (1, 2):
+        if x.type.ndim != 2:
             raise TypeError(
-                "Input 0 (0-indexed) must have ndim of "
-                f"1 or 2, {int(inputs[0].ndim)} given. Consider calling "
-                "pytensor.tensor.dot instead."
+                f"Dot Op expects a 2D tensor as input 0, got {x} with {x.type.ndim} dimensions"
             )
-        if inputs[1].ndim not in (1, 2):
+        if y.type.ndim != 2:
             raise TypeError(
-                "Input 1 (0-indexed) must have ndim of "
-                f"1 or 2, {int(inputs[1].ndim)} given. Consider calling "
-                "pytensor.tensor.dot instead."
+                f"Dot Op expects a 2D tensor as input 1, got {y} with {y.type.ndim} dimensions"
             )

-        sx, sy = (input.type.shape for input in inputs)
+        sx, sy = x.type.shape, y.type.shape
         if sx[-1] is not None and sy[0] is not None and sx[-1] != sy[0]:
             raise ValueError(
                 f"Incompatible shared dimension for dot product: {sx}, {sy}"
             )
+        sz = sx[:-1] + sy[-1:]
+        outputs = [tensor(dtype=ps.upcast(x.type.dtype, y.type.dtype), shape=sz)]
+        return Apply(self, [x, y], outputs)

-        if len(sy) == 2:
-            sz = sx[:-1] + sy[-1:]
-        elif len(sy) == 1:
-            sz = sx[:-1]
-
-        i_dtypes = [input.type.dtype for input in inputs]
-        outputs = [tensor(dtype=ps.upcast(*i_dtypes), shape=sz)]
-        return Apply(self, inputs, outputs)
-
-    def perform(self, node, inp, out):
-        x, y = inp
-        (z,) = out
-
-        # the asarray is here because dot between two vectors
-        # gives a numpy float object but we need to return a 0d
-        # ndarray
-        z[0] = np.asarray(np.dot(x, y))
+    def perform(self, node, inputs, output_storage):
+        output_storage[0][0] = np.matmul(*inputs)

     def grad(self, inp, grads):
         x, y = inp
         (gz,) = grads
-        xdim, ydim, gdim = x.type.ndim, y.type.ndim, gz.type.ndim
-
-        # grad is scalar, so x is vector and y is vector
-        if gdim == 0:
-            xgrad = gz * y
-            ygrad = gz * x
-
-        # x is vector, y is matrix, grad is vector
-        elif xdim == 1 and ydim == 2:
-            xgrad = dot(gz, y.T)
-            ygrad = outer(x.T, gz)

-        # x is matrix, y is vector, grad is vector
-        elif xdim == 2 and ydim == 1:
-            xgrad = outer(gz, y.T)
-            ygrad = dot(x.T, gz)
-
-        # x is matrix, y is matrix, grad is matrix
-        elif xdim == ydim == 2:
-            xgrad = dot(gz, y.T)
-            ygrad = dot(x.T, gz)
+        xgrad = self(gz, y.T)
+        ygrad = self(x.T, gz)

         # If x or y contain broadcastable dimensions but only one of
         # them know that a matching dimensions is broadcastable, the
         # above code don't always return the right broadcast pattern.
         # This cause problem down the road. See gh-1461.
-        if xgrad.broadcastable != x.broadcastable:
-            xgrad = specify_broadcastable(
-                xgrad, *(ax for (ax, b) in enumerate(x.type.broadcastable) if b)
-            )
-        if ygrad.broadcastable != y.broadcastable:
-            ygrad = specify_broadcastable(
-                ygrad, *(ax for (ax, b) in enumerate(y.type.broadcastable) if b)
-            )
+        if xgrad.type.shape != x.type.shape:
+            xgrad = specify_shape(xgrad, x.type.shape)
+        if ygrad.type.shape != y.type.shape:
+            ygrad = specify_shape(ygrad, y.type.shape)

-        rval = xgrad, ygrad
+        if xgrad.type.dtype not in float_dtypes:
+            raise TypeError("Dot grad x output must be a float type")
+        if ygrad.type.dtype not in float_dtypes:
+            raise TypeError("Dot grad y output must be a float type")

-        for elem in rval:
-            assert elem.dtype.find("float") != -1
-
-        return rval
+        return xgrad, ygrad

     def R_op(self, inputs, eval_points):
         # R_op for a \dot b evaluated at c for a and d for b is
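
With both inputs guaranteed to be matrices, the gradient reduces to the two identities used above: for Z = X @ Y with upstream gradient dZ, dX = dZ @ Y.T and dY = X.T @ dZ. A small NumPy check (not part of the commit; names are illustrative):

import numpy as np

rng = np.random.default_rng(0)
X, Y = rng.normal(size=(3, 4)), rng.normal(size=(4, 2))
dZ = rng.normal(size=(3, 2))                 # same shape as Z = X @ Y

dX = dZ @ Y.T                                # shape (3, 4), matches X
dY = X.T @ dZ                                # shape (4, 2), matches Y

# Finite-difference check of one entry, for the scalar loss sum(dZ * (X @ Y))
loss = lambda A, B: float((dZ * (A @ B)).sum())
eps = 1e-6
E = np.zeros_like(X)
E[1, 2] = eps
assert np.isclose((loss(X + E, Y) - loss(X - E, Y)) / (2 * eps), dX[1, 2])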
@@ -3116,24 +3075,7 @@ def R_op(self, inputs, eval_points):

     def infer_shape(self, fgraph, node, shapes):
         xshp, yshp = shapes
-        x, y = node.inputs
-
-        # vector / vector
-        if x.ndim == 1 and y.ndim == 1:
-            return [()]
-        # matrix / vector
-        if x.ndim == 2 and y.ndim == 1:
-            return [xshp[:-1]]
-        # vector / matrix
-        if x.ndim == 1 and y.ndim == 2:
-            return [yshp[-1:]]
-        # matrix / matrix
-        if x.ndim == 2 and y.ndim == 2:
-            return [xshp[:-1] + yshp[-1:]]
-        raise NotImplementedError()
-
-    def __str__(self):
-        return "dot"
+        return [[xshp[0], yshp[1]]]


 _dot = Dot()
@@ -3215,7 +3157,24 @@ def dense_dot(a, b):
     elif a.ndim > 2 or b.ndim > 2:
         return tensordot(a, b, [[a.ndim - 1], [np.maximum(0, b.ndim - 2)]])
     else:
-        return _dot(a, b)
+        row_vector = a.ndim == 1
+        if row_vector:
+            # Promote to row matrix
+            a = a[None]
+
+        col_vector = b.ndim == 1
+        if col_vector:
+            # Promote to column matrix
+            b = b[:, None]
+
+        out = _dot(a, b)
+        if row_vector:
+            # If we promoted a to a row matrix, we need to squeeze the first dimension
+            out = out.squeeze(0)
+        if col_vector:
+            # If we promoted b to a column matrix, we need to squeeze the last dimension
+            out = out.squeeze(-1)
+        return out


 def tensordot(
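
Since the Op itself now rejects vectors, dense_dot reproduces the old 1-D behaviour with the promote-then-squeeze pattern added above. A minimal NumPy sketch of the same pattern (not the pytensor code; function names are illustrative):

import numpy as np

def matrix_only_dot(a, b):
    assert a.ndim == b.ndim == 2
    return a @ b

def dense_dot_like(a, b):
    row_vector = a.ndim == 1
    if row_vector:
        a = a[None]          # (n,) -> (1, n)
    col_vector = b.ndim == 1
    if col_vector:
        b = b[:, None]       # (n,) -> (n, 1)
    out = matrix_only_dot(a, b)
    if row_vector:
        out = out.squeeze(0)
    if col_vector:
        out = out.squeeze(-1)
    return out

v, w = np.arange(3.0), np.ones(3)
M = np.arange(6.0).reshape(3, 2)
assert dense_dot_like(v, w).shape == ()      # inner product -> 0-d
assert dense_dot_like(v, M).shape == (2,)    # vector @ matrix -> vector
assert dense_dot_like(M.T, v).shape == (2,)  # matrix @ vector -> vector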
@@ -3921,11 +3880,7 @@ def logsumexp(x, axis=None, keepdims=False):
     return log(sum(exp(x), axis=axis, keepdims=keepdims))


-_matmul = Blockwise(
-    _dot,
-    signature="(m,k),(k,n)->(m,n)",
-    gufunc_spec=("numpy.matmul", 2, 1),
-)
+_matmul = Blockwise(_dot, name="Matmul")


 def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
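
Dot now carries gufunc_signature = "(m,n),(n,p)->(m,p)" and gufunc_spec = ("matmul", 2, 1), which appears to be why Blockwise no longer needs an explicit signature here. A small NumPy illustration (not the pytensor dispatch code) of what that signature means: leading dimensions are batch dimensions and broadcast, exactly as in numpy.matmul.

import numpy as np

x = np.random.default_rng(0).normal(size=(5, 1, 3, 4))   # batch dims (5, 1), core (3, 4)
y = np.random.default_rng(1).normal(size=(7, 4, 2))      # batch dim (7,), core (4, 2)

z = np.matmul(x, y)             # core (3, 4) @ (4, 2) -> (3, 2)
assert z.shape == (5, 7, 3, 2)  # batch dims broadcast to (5, 7)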

pytensor/tensor/rewriting/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -1,7 +1,6 @@
 import pytensor.tensor.rewriting.basic
 import pytensor.tensor.rewriting.blas
 import pytensor.tensor.rewriting.blas_c
-import pytensor.tensor.rewriting.blas_scipy
 import pytensor.tensor.rewriting.blockwise
 import pytensor.tensor.rewriting.einsum
 import pytensor.tensor.rewriting.elemwise

pytensor/tensor/rewriting/blas.py

Lines changed: 1 addition & 17 deletions
@@ -107,7 +107,6 @@
 )
 from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
 from pytensor.tensor.type import (
-    DenseTensorType,
     TensorType,
     integer_dtypes,
     values_eq_approx_remove_inf_nan,
@@ -580,29 +579,14 @@ def print_profile(cls, stream, prof, level=0):
 def local_dot_to_dot22(fgraph, node):
     # This works for tensor.outer too because basic.outer is a macro that
     # produces a dot(dimshuffle,dimshuffle) of form 4 below
-    if not isinstance(node.op, Dot):
-        return
-
-    if any(not isinstance(i.type, DenseTensorType) for i in node.inputs):
-        return False
-
     x, y = node.inputs
     if y.type.dtype != x.type.dtype:
         # TODO: upcast one so the types match
         _logger.info(f"Not optimizing dot with inputs {x} {y} {x.type} {y.type}")
         return

     if y.type.dtype in ("float16", "float32", "float64", "complex64", "complex128"):
-        if x.ndim == 2 and y.ndim == 2:
-            new_out = [_dot22(*node.inputs)]
-        elif x.ndim == 2 and y.ndim == 1:
-            new_out = [_dot22(x, y.dimshuffle(0, "x")).dimshuffle(0)]
-        elif x.ndim == 1 and y.ndim == 2:
-            new_out = [_dot22(x.dimshuffle("x", 0), y).dimshuffle(1)]
-        elif x.ndim == 1 and y.ndim == 1:
-            new_out = [_dot22(x.dimshuffle("x", 0), y.dimshuffle(0, "x")).dimshuffle()]
-        else:
-            return
+        new_out = [_dot22(*node.inputs)]
         copy_stack_trace(node.outputs, new_out)
         return new_out

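With Dot restricted to matrices, this rewrite no longer needs to branch on input ndim, so only the 2-D case survives. A hedged way to see which BLAS-backed op the rewrites ultimately pick (ordinary pytensor API usage, not code from this commit) is to print the optimized graph of a compiled matrix dot:

import pytensor
import pytensor.tensor as pt

x, y = pt.matrix("x"), pt.matrix("y")
f = pytensor.function([x, y], pt.dot(x, y))
pytensor.dprint(f)   # the optimized graph shows the Dot22/Gemm-style op chosen by the rewrites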