Skip to content

Commit 1f9a67b

Browse files
Add linalg Ops to MLX backend (#1700)
Co-authored-by: jessegrabowski <jessegrabowski@gmail.com>
1 parent f83c05b commit 1f9a67b

File tree

5 files changed

+330
-0
lines changed

5 files changed

+330
-0
lines changed

pytensor/link/mlx/dispatch/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,6 @@
1212
import pytensor.link.mlx.dispatch.blockwise
1313
import pytensor.link.mlx.dispatch.extra_ops
1414
import pytensor.link.mlx.dispatch.sort
15+
import pytensor.link.mlx.dispatch.slinalg
16+
import pytensor.link.mlx.dispatch.nlinalg
1517
# isort: on
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import mlx.core as mx
2+
3+
from pytensor.link.mlx.dispatch.basic import mlx_funcify
4+
from pytensor.tensor.nlinalg import SVD, KroneckerProduct, MatrixInverse, MatrixPinv
5+
6+
7+
@mlx_funcify.register(SVD)
8+
def mlx_funcify_SVD(op, node, **kwargs):
9+
full_matrices = op.full_matrices
10+
compute_uv = op.compute_uv
11+
12+
X_dtype = getattr(mx, node.inputs[0].dtype)
13+
14+
if not full_matrices:
15+
raise TypeError("full_matrices=False is not supported in the mlx backend.")
16+
17+
def svd_S_only(x):
18+
return mx.linalg.svd(
19+
x.astype(dtype=X_dtype, stream=mx.cpu), compute_uv=False, stream=mx.cpu
20+
)
21+
22+
def svd_full(x):
23+
outputs = mx.linalg.svd(
24+
x.astype(dtype=X_dtype, stream=mx.cpu), compute_uv=True, stream=mx.cpu
25+
)
26+
return outputs
27+
28+
if compute_uv:
29+
return svd_full
30+
else:
31+
return svd_S_only
32+
33+
34+
@mlx_funcify.register(KroneckerProduct)
35+
def mlx_funcify_KroneckerProduct(op, node, **kwargs):
36+
otype = node.outputs[0].dtype
37+
stream = mx.cpu if otype == "float64" else mx.gpu
38+
39+
A_dtype = getattr(mx, node.inputs[0].dtype)
40+
B_dtype = getattr(mx, node.inputs[1].dtype)
41+
42+
def kron(a, b):
43+
return mx.kron(
44+
a.astype(dtype=A_dtype, stream=stream),
45+
b.astype(dtype=B_dtype, stream=stream),
46+
stream=stream,
47+
)
48+
49+
return kron
50+
51+
52+
@mlx_funcify.register(MatrixInverse)
53+
def mlx_funcify_MatrixInverse(op, node, **kwargs):
54+
X_dtype = getattr(mx, node.inputs[0].dtype)
55+
56+
def inv(x):
57+
return mx.linalg.inv(x.astype(dtype=X_dtype, stream=mx.cpu), stream=mx.cpu)
58+
59+
return inv
60+
61+
62+
@mlx_funcify.register(MatrixPinv)
63+
def mlx_funcify_MatrixPinv(op, node, **kwargs):
64+
x_dtype = getattr(mx, node.inputs[0].dtype)
65+
66+
def pinv(x):
67+
return mx.linalg.pinv(x.astype(dtype=x_dtype, stream=mx.cpu), stream=mx.cpu)
68+
69+
return pinv
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import warnings
2+
3+
import mlx.core as mx
4+
5+
from pytensor.link.mlx.dispatch.basic import mlx_funcify
6+
from pytensor.tensor.slinalg import LU, Cholesky, Solve, SolveTriangular
7+
8+
9+
@mlx_funcify.register(Cholesky)
10+
def mlx_funcify_Cholesky(op, node, **kwargs):
11+
lower = op.lower
12+
a_dtype = getattr(mx, node.inputs[0].dtype)
13+
14+
def cholesky(a):
15+
return mx.linalg.cholesky(
16+
a.astype(dtype=a_dtype, stream=mx.cpu), upper=not lower, stream=mx.cpu
17+
)
18+
19+
return cholesky
20+
21+
22+
@mlx_funcify.register(Solve)
23+
def mlx_funcify_Solve(op, node, **kwargs):
24+
assume_a = op.assume_a
25+
a_dtype = getattr(mx, node.inputs[0].dtype)
26+
b_dtype = getattr(mx, node.inputs[1].dtype)
27+
28+
if assume_a != "gen":
29+
warnings.warn(
30+
f"MLX solve does not support assume_a={op.assume_a}. Defaulting to assume_a='gen'.",
31+
UserWarning,
32+
)
33+
34+
def solve(a, b):
35+
# MLX only supports solve on CPU
36+
return mx.linalg.solve(
37+
a.astype(stream=mx.cpu, dtype=a_dtype),
38+
b.astype(stream=mx.cpu, dtype=b_dtype),
39+
stream=mx.cpu,
40+
)
41+
42+
return solve
43+
44+
45+
@mlx_funcify.register(SolveTriangular)
46+
def mlx_funcify_SolveTriangular(op, node, **kwargs):
47+
lower = op.lower
48+
A_dtype = getattr(mx, node.inputs[0].dtype)
49+
b_dtype = getattr(mx, node.inputs[1].dtype)
50+
51+
def solve_triangular(A, b):
52+
return mx.linalg.solve_triangular(
53+
A.astype(stream=mx.cpu, dtype=A_dtype),
54+
b.astype(stream=mx.cpu, dtype=b_dtype),
55+
upper=not lower,
56+
stream=mx.cpu, # MLX only supports solve_triangular on CPU
57+
)
58+
59+
return solve_triangular
60+
61+
62+
@mlx_funcify.register(LU)
63+
def mlx_funcify_LU(op, node, **kwargs):
64+
permute_l = op.permute_l
65+
A_dtype = getattr(mx, node.inputs[0].dtype)
66+
p_indices = op.p_indices
67+
68+
if permute_l:
69+
raise ValueError("permute_l=True is not supported in the mlx backend.")
70+
if not p_indices:
71+
raise ValueError("p_indices=False is not supported in the mlx backend.")
72+
73+
def lu(a):
74+
p_idx, L, U = mx.linalg.lu(
75+
a.astype(dtype=A_dtype, stream=mx.cpu), stream=mx.cpu
76+
)
77+
78+
return (
79+
p_idx.astype(mx.int32, stream=mx.cpu),
80+
L,
81+
U,
82+
)
83+
84+
return lu

tests/link/mlx/test_nlinalg.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from functools import partial
2+
3+
import numpy as np
4+
import pytest
5+
6+
import pytensor.tensor as pt
7+
from pytensor import config
8+
from tests.link.mlx.test_basic import compare_mlx_and_py, mlx_mode
9+
10+
11+
@pytest.mark.parametrize("compute_uv", [True, False])
12+
def test_mlx_svd(compute_uv):
13+
rng = np.random.default_rng(15)
14+
15+
A = pt.matrix(name="X")
16+
A_val = rng.normal(size=(3, 3)).astype(config.floatX)
17+
A_val = A_val @ A_val.T
18+
19+
out = pt.linalg.svd(A, compute_uv=compute_uv)
20+
21+
compare_mlx_and_py(
22+
[A],
23+
out,
24+
[A_val],
25+
mlx_mode=mlx_mode,
26+
assert_fn=partial(np.testing.assert_allclose, atol=1e-6, strict=True),
27+
)
28+
29+
30+
def test_mlx_kron():
31+
rng = np.random.default_rng(15)
32+
33+
A = pt.matrix(name="A")
34+
B = pt.matrix(name="B")
35+
A_val, B_val = rng.normal(scale=0.1, size=(2, 3, 3)).astype(config.floatX)
36+
out = pt.linalg.kron(A, B)
37+
38+
compare_mlx_and_py(
39+
[A, B],
40+
[out],
41+
[A_val, B_val],
42+
mlx_mode=mlx_mode,
43+
assert_fn=partial(np.testing.assert_allclose, atol=1e-6, strict=True),
44+
)
45+
46+
47+
@pytest.mark.parametrize("op", [pt.linalg.inv, pt.linalg.pinv], ids=["inv", "pinv"])
48+
def test_mlx_inv(op):
49+
rng = np.random.default_rng(15)
50+
n = 3
51+
52+
A = pt.matrix(name="A")
53+
A_val = rng.normal(size=(n, n))
54+
A_val = (A_val @ A_val.T).astype(config.floatX)
55+
56+
out = op(A)
57+
58+
compare_mlx_and_py(
59+
[A],
60+
[out],
61+
[A_val],
62+
mlx_mode=mlx_mode,
63+
assert_fn=partial(
64+
np.testing.assert_allclose, atol=1e-6, rtol=1e-6, strict=True
65+
),
66+
)

tests/link/mlx/test_slinalg.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import contextlib
2+
from functools import partial
3+
4+
import numpy as np
5+
import pytest
6+
7+
import pytensor.tensor as pt
8+
from pytensor import config
9+
from tests.link.mlx.test_basic import compare_mlx_and_py, mlx_mode
10+
11+
12+
@pytest.mark.parametrize("lower", [True, False])
13+
def test_mlx_cholesky(lower):
14+
rng = np.random.default_rng(15)
15+
n = 3
16+
17+
A = pt.tensor("A", shape=(n, n))
18+
A_val = rng.normal(size=(n, n))
19+
A_val = (A_val @ A_val.T).astype(config.floatX)
20+
21+
out = pt.linalg.cholesky(A, lower=lower)
22+
23+
compare_mlx_and_py(
24+
[A],
25+
[out],
26+
[A_val],
27+
mlx_mode=mlx_mode,
28+
assert_fn=partial(np.testing.assert_allclose, atol=1e-6, strict=True),
29+
)
30+
31+
32+
@pytest.mark.parametrize("assume_a", ["gen", "pos"])
33+
def test_mlx_solve(assume_a):
34+
rng = np.random.default_rng(15)
35+
n = 3
36+
37+
A = pt.tensor("A", shape=(n, n))
38+
b = pt.tensor("B", shape=(n, n))
39+
40+
out = pt.linalg.solve(A, b, b_ndim=2, assume_a=assume_a)
41+
42+
A_val = rng.normal(size=(n, n)).astype(config.floatX)
43+
A_val = A_val @ A_val.T
44+
45+
b_val = rng.normal(size=(n, n)).astype(config.floatX)
46+
47+
context = (
48+
contextlib.suppress()
49+
if assume_a == "gen"
50+
else pytest.warns(
51+
UserWarning, match=f"MLX solve does not support assume_a={assume_a}"
52+
)
53+
)
54+
55+
with context:
56+
compare_mlx_and_py(
57+
[A, b],
58+
[out],
59+
[A_val, b_val],
60+
mlx_mode=mlx_mode,
61+
assert_fn=partial(
62+
np.testing.assert_allclose, atol=1e-6, rtol=1e-6, strict=True
63+
),
64+
)
65+
66+
67+
@pytest.mark.parametrize("lower, trans", [(False, False), (True, True)])
68+
def test_mlx_SolveTriangular(lower, trans):
69+
rng = np.random.default_rng(15)
70+
71+
A = pt.tensor("A", shape=(5, 5))
72+
b = pt.tensor("B", shape=(5, 5))
73+
74+
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
75+
b_val = rng.normal(size=(5, 5)).astype(config.floatX)
76+
77+
out = pt.linalg.solve_triangular(
78+
A,
79+
b,
80+
trans=0,
81+
lower=lower,
82+
unit_diagonal=False,
83+
)
84+
compare_mlx_and_py(
85+
[A, b],
86+
[out],
87+
[A_val, b_val],
88+
mlx_mode=mlx_mode,
89+
assert_fn=partial(
90+
np.testing.assert_allclose, atol=1e-6, rtol=1e-6, strict=True
91+
),
92+
)
93+
94+
95+
def test_mlx_LU():
96+
rng = np.random.default_rng(15)
97+
98+
A = pt.tensor("A", shape=(5, 5))
99+
out = pt.linalg.lu(A, permute_l=False, p_indices=True)
100+
101+
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
102+
103+
compare_mlx_and_py(
104+
[A],
105+
out,
106+
[A_val],
107+
mlx_mode=mlx_mode,
108+
assert_fn=partial(np.testing.assert_allclose, atol=1e-6, strict=True),
109+
)

0 commit comments

Comments
 (0)