add more compile() overloads -- up to arity 8 on input and 4 on output#394
add more compile() overloads -- up to arity 8 on input and 4 on output#394davidkoski merged 1 commit intomainfrom
Conversation
- this should cover _most_ cases - there is still ([MLXArray]) -> [MLXArray] as a fallback
| @@ -0,0 +1,532 @@ | |||
| // Copyright © 2026 Apple Inc. | |||
|
|
|||
| import Cmlx | |||
There was a problem hiding this comment.
We already have overloads for e.g. (MLXArray, MLXArray) -> MLXArray. This adds more complete coverage. I also experimented with macros, but nothing compelling.
Current tech allows something like this:
private func gelu(_ x: MLXArray) -> MLXArray {
x * (1 + erf(x / sqrt(2))) / 2
}
let compiledGelu = compile(gelu)
// use
let x: MLXArray
let result = compiledGelu(x)Pretty simple and now supports more arguments and return values.
For macros I tried some variants:
let compiledGelu2 = #MLXCompile({
(x: MLXArray) -> MLXArray in
x * (1 + erf(x / sqrt(2))) / 2
})
// use
let x: MLXArray
let result = compiledGelu2(x)Not bad -- we can generate the pack/unpack to [MLXArray] in the macro, which is nice. Not as nice is that we have to use closure syntax and not even trailing closure (outside the parens). I don't think it provides a compelling solution.
Another variant:
@MLXCompile({ (x: MLXArray) -> MLXArray in
x * (1 + erf(x / sqrt(2))) / 2
})
public func compiledGelu3(_ x: MLXArray) -> MLXArray
// use
let x: MLXArray
let result = compiledGelu2(x)The nice thing about this is you are defining a function and can put labels on the arguments, e.g. gelu(x: x), if you want. The downside is there is even more ceremony -- you are defining the arguments twice.
So I think what we have is pretty good.
There was a problem hiding this comment.
Yeah I agree the macros are not helping here. I left a comment on a potential extra option not sure if it is useful in swift. Also let me know if you want me to expand on the python implementation.
angeloskath
left a comment
There was a problem hiding this comment.
I like this!
Not sure if it would be useful and swift natural to maybe pass and receive dictionaries from string to array (I think string to Any would be annoying in Swift but maybe I am wrong).
So perhaps it makes sense to be able to compile
private func complexFunc(_ args: [String: MLXArray]) -> MLXArray {
args["foo"] * args["bar"]
}and/or dictionaries as return values as well.
In the python side we have a slightly complicated mechanism but it does provide a very good experience as one can just compile any function taking any arbitrary arguments in any standard container including constants of basic values.
Maybe, but the return type would be It does seem useful as an option for sure. |
Proposed changes
Checklist
Put an
xin the boxes that apply.pre-commit run --all-filesto format my code / installed pre-commit prior to committing changes