Skip to content

add more compile() overloads -- up to arity 8 on input and 4 on output#394

Merged
davidkoski merged 1 commit intomainfrom
compile-overloads
Apr 13, 2026
Merged

add more compile() overloads -- up to arity 8 on input and 4 on output#394
davidkoski merged 1 commit intomainfrom
compile-overloads

Conversation

@davidkoski
Copy link
Copy Markdown
Collaborator

Proposed changes

  • this should cover most cases
  • there is still ([MLXArray]) -> [MLXArray] as a fallback

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

- this should cover _most_ cases
- there is still ([MLXArray]) -> [MLXArray] as a fallback
@davidkoski davidkoski requested a review from angeloskath April 10, 2026 20:22
@@ -0,0 +1,532 @@
// Copyright © 2026 Apple Inc.

import Cmlx
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already have overloads for e.g. (MLXArray, MLXArray) -> MLXArray. This adds more complete coverage. I also experimented with macros, but nothing compelling.

Current tech allows something like this:

private func gelu(_ x: MLXArray) -> MLXArray {
    x * (1 + erf(x / sqrt(2))) / 2
}

let compiledGelu = compile(gelu)

// use
let x: MLXArray
let result = compiledGelu(x)

Pretty simple and now supports more arguments and return values.

For macros I tried some variants:

let compiledGelu2 = #MLXCompile({
    (x: MLXArray) -> MLXArray in
    x * (1 + erf(x / sqrt(2))) / 2
})

// use
let x: MLXArray
let result = compiledGelu2(x)

Not bad -- we can generate the pack/unpack to [MLXArray] in the macro, which is nice. Not as nice is that we have to use closure syntax and not even trailing closure (outside the parens). I don't think it provides a compelling solution.

Another variant:

@MLXCompile({ (x: MLXArray) -> MLXArray in
    x * (1 + erf(x / sqrt(2))) / 2
})
public func compiledGelu3(_ x: MLXArray) -> MLXArray

// use
let x: MLXArray
let result = compiledGelu2(x)

The nice thing about this is you are defining a function and can put labels on the arguments, e.g. gelu(x: x), if you want. The downside is there is even more ceremony -- you are defining the arguments twice.

So I think what we have is pretty good.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I agree the macros are not helping here. I left a comment on a potential extra option not sure if it is useful in swift. Also let me know if you want me to expand on the python implementation.

Copy link
Copy Markdown
Member

@angeloskath angeloskath left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this!

Not sure if it would be useful and swift natural to maybe pass and receive dictionaries from string to array (I think string to Any would be annoying in Swift but maybe I am wrong).

So perhaps it makes sense to be able to compile

private func complexFunc(_ args: [String: MLXArray]) -> MLXArray {
    args["foo"] * args["bar"]
}

and/or dictionaries as return values as well.

In the python side we have a slightly complicated mechanism but it does provide a very good experience as one can just compile any function taking any arbitrary arguments in any standard container including constants of basic values.

@davidkoski
Copy link
Copy Markdown
Collaborator Author

So perhaps it makes sense to be able to compile

private func complexFunc(_ args: [String: MLXArray]) -> MLXArray {
    args["foo"] * args["bar"]
}

and/or dictionaries as return values as well.

Maybe, but the return type would be MLXArray? and you would either have to have error handling on that or have to args["foo"]! (force unwrap, basically a hard assert).

It does seem useful as an option for sure.

@davidkoski davidkoski merged commit 09fdeff into main Apr 13, 2026
8 checks passed
@davidkoski davidkoski deleted the compile-overloads branch April 13, 2026 23:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants