2 changes: 2 additions & 0 deletions pytensor/link/mlx/dispatch/__init__.py
@@ -10,4 +10,6 @@
import pytensor.link.mlx.dispatch.signal
import pytensor.link.mlx.dispatch.signal.conv
import pytensor.link.mlx.dispatch.blockwise
import pytensor.link.mlx.dispatch.extra_ops
import pytensor.link.mlx.dispatch.sort
# isort: on
35 changes: 35 additions & 0 deletions pytensor/link/mlx/dispatch/extra_ops.py
@@ -0,0 +1,35 @@
import mlx.core as mx

from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.tensor.extra_ops import CumOp, Repeat


@mlx_funcify.register(CumOp)
def mlx_funcify_CumOp(op, **kwargs):
    axis = op.axis
    mode = op.mode

    def cumop(x, axis=axis, mode=mode):
        match mode:
            case "add":
                return mx.cumsum(x, axis=axis)
            case "mul":
                return mx.cumprod(x, axis=axis)
            case _:
                raise NotImplementedError(f"CumOp mode {mode} not implemented in MLX")

    return cumop


@mlx_funcify.register(Repeat)
def mlx_funcify_Repeat(op, **kwargs):
    axis = op.axis

    def repeat(x, repeats, axis=axis):
        if not isinstance(repeats, int):
Review comments on this line:

Member: This is known at dispatch time, raise then?

Member Author: repeats is a symbolic input. We only know axis at dispatch time.

Member Author: It's a weird MLX-specific limitation. There might be a work-around, but I don't want to do it in this PR. Just getting some common cases covered.

Member: You know whether repeats is 0d or 1d at dispatch time? Actually, isn't the op always 1d now and we use alloc for 0d?

Member Author: How? It's a symbolic input.

Member Author: Oh, by checking op.inputs[1] :fivehead:

Member: By checking node.inputs. Node is the second argument to all dispatches. (See the sketch after this file's diff.)
            raise NotImplementedError(
                "MLX repeat does not support sequence-valued repeat argument."
            )
        return mx.repeat(x, repeats, axis=axis)

    return repeat
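
Following the review thread, here is a rough sketch, not part of this PR, of how the sequence-valued check could be moved to dispatch time by inspecting node.inputs. It assumes the node is passed to the dispatch alongside the op (as the reviewer notes) and that node.inputs[1] is the repeats variable; treat it as illustrative rather than the backend's actual implementation.

import mlx.core as mx

from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.tensor.extra_ops import Repeat


@mlx_funcify.register(Repeat)
def mlx_funcify_Repeat(op, node, **kwargs):
    axis = op.axis

    # repeats is symbolic, but its number of dimensions is visible in the graph:
    # a vector of per-element counts can never be lowered to mx.repeat's single
    # integer count, so reject it while compiling instead of on the first call.
    if node.inputs[1].type.ndim > 0:
        raise NotImplementedError(
            "MLX repeat does not support sequence-valued repeat argument."
        )

    def repeat(x, repeats, axis=axis):
        # mx.repeat expects an integer count at runtime; a scalar mx.array
        # input may still need an explicit conversion.
        return mx.repeat(x, repeats, axis=axis)

    return repeat

With the check at dispatch time, the error surfaces while the MLX function is being built rather than on the first evaluation.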
32 changes: 32 additions & 0 deletions pytensor/link/mlx/dispatch/sort.py
@@ -0,0 +1,32 @@
import mlx.core as mx

from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.tensor.sort import ArgSortOp, SortOp


@mlx_funcify.register(SortOp)
def mlx_funcify_Sort(op, **kwargs):
    kind = op.kind

    def sort(x, axis, kind=kind):
        if kind != "quicksort":
            raise NotImplementedError(
                f"MLX sort does not support kind={kind}, only 'quicksort'."
            )
        return mx.sort(x, axis=axis)

    return sort


@mlx_funcify.register(ArgSortOp)
def mlx_funcify_ArgSort(op, **kwargs):
    kind = op.kind

    def argsort(x, axis, kind=kind):
        if kind != "quicksort":
            raise NotImplementedError(
                f"MLX argsort does not support kind={kind}, only 'quicksort'."
            )
        return mx.argsort(x, axis=axis)

    return argsort
24 changes: 24 additions & 0 deletions tests/link/mlx/test_extra_ops.py
@@ -0,0 +1,24 @@
import numpy as np
import pytest

from pytensor.configdefaults import config
from pytensor.tensor import extra_ops as pt_extra_ops
from pytensor.tensor.type import matrix
from tests.link.mlx.test_basic import compare_mlx_and_py


mx = pytest.importorskip("mlx.core")


def test_extra_ops():
    a = matrix("a")
    a_test = np.arange(6, dtype=config.floatX).reshape((3, 2))

    out = pt_extra_ops.cumsum(a, axis=0)
    compare_mlx_and_py([a], [out], [a_test])

    out = pt_extra_ops.cumprod(a, axis=1)
    compare_mlx_and_py([a], [out], [a_test])

    out = pt_extra_ops.repeat(a, 3, axis=1)
    compare_mlx_and_py([a], [out], [a_test])
15 changes: 15 additions & 0 deletions tests/link/mlx/test_sort.py
@@ -0,0 +1,15 @@
import numpy as np
import pytest

from pytensor.tensor.sort import argsort, sort
from pytensor.tensor.type import matrix
from tests.link.mlx.test_basic import compare_mlx_and_py


@pytest.mark.parametrize("axis", [None, -1])
@pytest.mark.parametrize("func", (sort, argsort))
def test_sort(func, axis):
    x = matrix("x", shape=(2, 2), dtype="float64")
    out = func(x, axis=axis)
    arr = np.array([[1.0, 4.0], [5.0, 2.0]])
    compare_mlx_and_py([x], [out], [arr])