Skip to content

Fix vmap + floor_divide: preserve integer dtype#3292

Merged
zcbenz merged 4 commits intoml-explore:mainfrom
robert-johansson:fix-vmap-floor-divide
Mar 24, 2026
Merged

Fix vmap + floor_divide: preserve integer dtype#3292
zcbenz merged 4 commits intoml-explore:mainfrom
robert-johansson:fix-vmap-floor-divide

Conversation

@robert-johansson
Copy link
Copy Markdown
Contributor

Summary

Divide::vmap calls divide(a, b) which promotes integers to float via at_least_float. But floor_divide creates a Divide primitive with integer dtype for integer division. Under vmap, the integer dtype is lost.

Before:

import mlx.core as mx
f = lambda s: mx.floor_divide(s, mx.array(5, mx.int32))
vf = mx.vmap(f)
result = vf(mx.arange(25, dtype=mx.int32))
# result.dtype == float32, values are [0.0, 0.2, 0.4, ...] (wrong)

After:

result = vf(mx.arange(25, dtype=mx.int32))
# result.dtype == int32, values are [0, 0, 0, 0, 0, 1, 1, ...] (correct)

Fix

In Divide::vmap, check if inputs are integer type and construct the Divide primitive directly (preserving dtype) instead of calling the high-level divide() which promotes to float.

Test plan

  • New test case test vmap floor_divide integer with 3 sub-tests:
    • floor_divide preserves int32 dtype under vmap
    • remainder preserves int32 dtype under vmap
    • floor_divide + remainder reconstructs original values under vmap
  • All existing vmap tests pass (38 assertions)

Robert Johansson and others added 2 commits March 21, 2026 21:33
Divide::vmap called divide() which promotes integers to float via
at_least_float. But floor_divide creates a Divide primitive with
integer dtype for integer division. The vmap of this primitive
should preserve integer semantics.

Fix: in Divide::vmap, check if inputs are integer type and construct
the Divide primitive directly (preserving dtype) instead of calling
the high-level divide() which promotes to float.

Before: vmap(floor_divide(12, 5)) → 2.4 (float32)
After:  vmap(floor_divide(12, 5)) → 2   (int32)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Copy link
Copy Markdown
Collaborator

@zcbenz zcbenz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this would make vmap + divide output wrong dtype for integers? Adding a flag in Divide might be a better solution?

@zcbenz zcbenz merged commit e18d4e9 into ml-explore:main Mar 24, 2026
16 checks passed
robert-johansson pushed a commit to robert-johansson/genmlx that referenced this pull request Mar 27, 2026
node-mlx now pins mlx at origin/main which contains latest
upstream (0ff1115a) + 6 GenMLX-only patches. The vmap/floor_divide
fix was dropped (merged upstream as ml-explore/mlx#3292).

git submodule update --init --recursive now works without needing
to know about any special branches.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants