Conversation
|
What an unexpected and amazing surprise! I'm absolutely thrilled. |
|
@awni |
|
I think this is good to stay as an experiment branch for some time while we work on core and CUDA. I don't think we have the bandwidth to merge this for a few months at least. Sorry if this is disappointing, @NripeshN; I don't mean to discourage you from working on it. |
|
I would love to see the ROCm backend get more traction. AMD's new AI series of processors has a similar advantage to Apple Silicon with unified memory, and getting MLX to run on those processors would be neat. |
|
Stole my idea :( |
|
How is this even possible for such an awesome PR to be left like this? |
Pull request overview
This PR adds experimental ROCm backend support to MLX, enabling execution on AMD GPUs. The implementation mirrors the CUDA backend structure, providing HIP-based implementations of core operations, memory management, and device handling.
Changes:
- Added ROCm backend infrastructure with device management, memory allocation, and stream handling
- Implemented HIP kernels for unary, binary, ternary operations, reductions, normalization (softmax, layer_norm, rms_norm), RoPE, and sorting
- Updated build system (CMake) to support ROCm compilation with configurable GPU architectures
Reviewed changes
Copilot reviewed 59 out of 59 changed files in this pull request and generated 13 comments.
| File | Description |
|---|---|
| CMakeLists.txt | Added MLX_BUILD_ROCM option and ROCm library detection |
| mlx/CMakeLists.txt | Integrated ROCm backend build configuration |
| mlx/device.cpp | Added ROCm device availability checks |
| mlx/backend/rocm/*.hip | HIP kernel implementations for various operations |
| mlx/backend/rocm/device.* | ROCm device and stream management |
| mlx/backend/rocm/allocator.* | ROCm-specific memory allocator using HIP unified memory |
| mlx/backend/rocm/worker.* | Async task execution worker for stream synchronization |
| mlx/backend/rocm/utils.* | HIP utility functions and error handling |
| mlx/backend/rocm/jit_module.* | JIT compilation support using HIPRTC |
| mlx/backend/rocm/device/*.hpp | Device-side utility functions and type definitions |
| mlx/backend/rocm/CMakeLists.txt | ROCm backend build configuration |
…ather, scatter, logsumexp, random bits generation, and sorting. Introduce new kernels for efficient computation and integrate with existing ROCm utilities. Update CMake configuration to include new source files and dependencies. Enhance error handling and ensure compatibility with different data types. This commit significantly expands the functionality of the ROCm backend.
|
👑👑👑 |
|
Can anyone run:

```shell
CMAKE_ARGS="-DMLX_BUILD_ROCM=ON" pip install -e .
```

or, to target a specific GPU architecture:

```shell
CMAKE_ARGS="-DMLX_BUILD_ROCM=ON -DMLX_ROCM_ARCHITECTURES={based on your GPU}" pip install -e .
```

Replace `{based on your GPU}` with your GPU architecture. You can run `rocm-smi` to get your GPU information. |
|
I'm getting this CMake error, running on Strix Halo (gfx1151): |
Could you retry with the latest push, please? (P.S. keep your fingers crossed while it compiles; it worked for me on the 138th time) 😅
… string formatting, replacing fmt library usage. Remove unused event.cpp file. Update kernel name generation and parameter formatting for consistency.
Now what can I test? 😍 |
|
I'm getting this: |
I forgot to test the Python build, my bad. Can you try it now? Unfortunately I might not be able to help after it compiles, since I don't have an AMD GPU to run tests 😔. I've tried replicating most things from CUDA, so hopefully it works.
Use a hash of the module name for hiprtcCreateProgram to avoid filesystem filename limits when HIP runtime compiler creates temporary files. Also add get_hsaco_path() helper to split long module names into nested directories for disk caching. This fixes JIT compilation failures with complex fused kernels that generate very long module names (>255 chars).
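The hashing-and-nesting scheme described above can be sketched in Python (the real implementation is C++ against HIPRTC; `hashed_program_name`, `hsaco_cache_path`, and the chunk size are hypothetical names chosen for illustration):

```python
import hashlib
import os

def hashed_program_name(module_name: str, max_len: int = 255) -> str:
    # Long fused-kernel names would exceed filesystem filename limits
    # when the runtime compiler writes temporary files, so swap in a
    # fixed-length hash whenever the name is too long.
    if len(module_name) <= max_len:
        return module_name
    return hashlib.sha1(module_name.encode()).hexdigest()

def hsaco_cache_path(cache_dir: str, module_name: str, chunk: int = 64) -> str:
    # Split a long module name into nested directories so every single
    # path component stays short when caching compiled code on disk.
    parts = [module_name[i:i + chunk] for i in range(0, len(module_name), chunk)]
    return os.path.join(cache_dir, *parts) + ".hsaco"
```

Hashing keeps the compiler-facing name bounded, while the nested cache path preserves a readable, collision-free layout on disk.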
HIP doesn't provide native math functions for hip_bfloat16 and __half, so add device function overloads that convert to float, compute, and convert back. This enables JIT-compiled kernels to use math operations on reduced-precision tensors. Functions added: abs, exp, log, sqrt, rsqrt, sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, asinh, acosh, atanh, ceil, floor, rint, log2, log10, log1pf, expm1f, erff, erfinvf, powf, fmodf, truncf, atan2f.
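The widen-compute-narrow pattern behind these overloads can be illustrated in Python using the fact that bfloat16 is simply the top 16 bits of an IEEE-754 float32 (a sketch of the idea, not MLX's actual device code):

```python
import math
import struct

def float_to_bf16(x: float) -> int:
    # bfloat16 is the top 16 bits of an IEEE-754 float32 (truncated).
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return bits >> 16

def bf16_to_float(b: int) -> float:
    (x,) = struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))
    return x

def bf16_exp(b: int) -> int:
    # The overload pattern: widen to float, use the float math
    # function, narrow the result back to the reduced-precision type.
    return float_to_bf16(math.exp(bf16_to_float(b)))
```

Each listed function (abs, exp, log, sqrt, …) follows the same three-step shape, paying a small precision cost that is already inherent to the 8-bit bfloat16 mantissa.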
Add has_only_singleton_batch_dims() helper to correctly detect when broadcasted singleton dimensions can be treated as non-batched matrices, fixing page faults and incorrect results in certain quantized matmul cases.
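A minimal sketch of what such a helper checks (Python stand-in for the C++ helper, assuming the usual convention that the trailing two dims are the matrix dims):

```python
def has_only_singleton_batch_dims(shape) -> bool:
    # The trailing two dims are the matrix dims; everything before
    # them is batch. If every batch dim is 1, the operand can be
    # treated as a plain non-batched matrix instead of taking the
    # batched dispatch path (which was faulting on these shapes).
    return all(dim == 1 for dim in shape[:-2])
```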
- Add qmv_warp_shared_batched_kernel to optimize batched QMV with singleton dimensions.
- Add gather_qmv_warp_shared_kernel to accelerate MoE gather operations during decode.
- Update dispatch logic in QuantizedMatmul::eval_gpu and GatherQMM::eval_gpu to use these fast paths.
Improves decoding speed for 4-bit and 6-bit quantized models by 10-15%. By reading up to 8 quantized values at once using uint32_t vector loads, we better saturate the memory bandwidth instead of doing multiple byte-sized loads. Also unskips passing tests in rocm_skip.py.
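The vectorized-load idea can be illustrated as nibble unpacking: one 32-bit word carries eight 4-bit quantized values (Python sketch; the low-nibble-first ordering is an assumption):

```python
def unpack_u4x8(word: int):
    # One 32-bit load carries eight 4-bit quantized values; on the GPU
    # this replaces eight byte-sized loads and better saturates memory
    # bandwidth. Low nibble first (assumed ordering).
    return [(word >> (4 * i)) & 0xF for i in range(8)]
```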
Tuning the number of threads per column to 16 rather than full WARP_SIZE significantly improves decoding generation performance (from 14.5 to 18.2 TPS on GLM-4 6bit) due to better hardware occupancy and register usage.
- Use sincosf() instead of separate cosf() + sinf() calls for better performance
- Add optimized 1D kernels (rope_single_1d, rope_single_freqs_1d) for single-token decode
- Use 256-thread 1D blocks instead of 16x16 2D blocks for small workloads
- Inline implementation in 1D kernels to reduce function call overhead

The decode case (B=1, T=1) now uses flat indexing, which provides better occupancy for the small number of elements typical in LLM decode steps.
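The decode-path math can be sketched as a reference rotation with one sin/cos pair per rotated feature pair (Python stand-in; the function name and the half-split pairing layout are assumptions, and the paired computation of sin and cos mirrors the sincosf() optimization):

```python
import math

def rope_decode_1d(x, position, base=10000.0):
    # Reference RoPE for a single decode token (B=1, T=1): each
    # feature pair (i, i + half) is rotated by one angle, so the sine
    # and cosine of that angle are computed together (sincosf on GPU).
    half = len(x) // 2
    out = list(x)
    for i in range(half):
        theta = position * base ** (-2.0 * i / len(x))
        s, c = math.sin(theta), math.cos(theta)
        out[i] = x[i] * c - x[i + half] * s
        out[i + half] = x[i] * s + x[i + half] * c
    return out
```

Because the rotation is length-preserving pair by pair, it is a cheap sanity check for the "different rounding" instabilities mentioned later in this thread.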
Tune quantized matmul path selection for decode/prefill shapes, add bounded dequant cache with safe source retention, and wire QMV block sizing heuristics. Extend ROCm SDPA/flash dispatch to head dim 256 and add a pointwise conv fast path to reduce launch overhead in decode-like workloads.
Key dequant-cache entries by GPU buffer pointers to avoid stale hits from array-id reuse, and align QMV thread/column defaults with architecture-aware warp sizing across both QMM and GatherQMM paths.
Prefer flash SDPA for decode-like BF16/F16 configurations with long KV cache and no masks, while preserving vector fallback behavior. Also skip the AddMM input copy when beta is zero to eliminate redundant device-to-device copy work.
Allow strided-batched GEMM when collapsed batch dimensions are uniformly strided (including flattened multi-dimensional batches) instead of restricting to single-dimension batches only. This reduces fallback per-batch launch overhead and keeps more matmuls on the rocBLAS batched path.
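The uniform-stride condition can be sketched as a contiguity check across the collapsed batch axes (Python sketch; the return convention is hypothetical):

```python
from math import prod

def collapse_batch_for_strided_gemm(shape, strides):
    # All dims except the trailing two matrix dims are batch dims.
    # They collapse into one uniformly strided batch when each batch
    # dim's stride equals the next inner batch dim's stride times its
    # size, i.e. the batch axes are contiguous relative to each other.
    bshape, bstrides = list(shape[:-2]), list(strides[:-2])
    if not bshape:
        return True, 1, 0
    for i in range(len(bshape) - 1):
        if bstrides[i] != bstrides[i + 1] * bshape[i + 1]:
            return False, 0, 0  # fall back to per-batch GEMM launches
    return True, prod(bshape), bstrides[-1]
```

When the check passes, a multi-dimensional batch like (2, 3) flattens to a single batch count of 6 with one uniform stride, so a single rocBLAS strided-batched call covers it.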
Add env-configurable rocBLAS solution-index selection for float32 and bfloat16 GEMM/strided-batched GEMM paths across matmul, quantized QMM dequant GEMM, and shared rocBLAS wrappers. Keep default behavior unchanged (index 0), and automatically fall back to standard algorithms if a configured solution index fails.
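The env-var-with-fallback behavior can be sketched as follows (Python sketch; the variable name is hypothetical, and per the commit message the fallback also triggers at call time if rocBLAS rejects the configured index):

```python
import os

def rocblas_solution_index(env_var: str, default: int = 0) -> int:
    # Index 0 keeps rocBLAS's standard algorithm selection; any other
    # value requests a specific solution. Unset or malformed values
    # fall back to the default so behavior is unchanged out of the box.
    raw = os.environ.get(env_var)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        return default
```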
Select QMV threads-per-column based on problem size instead of forcing warp-size on RDNA, and tune cols-per-block accordingly for 8-bit paths. This restores better out-of-box decode throughput on smaller models while preserving faster large-model defaults.
Use a larger shared-memory chunk (2048 vs 1024) in QMV warp-shared kernels to reduce chunk loop overhead and synchronization frequency. This improves out-of-box decode throughput on Qwen3.5 models without requiring runtime tuning knobs.
Deduplicate temporary buffer keepalive entries per command buffer to lower host-side bookkeeping and callback payload size, and raise the default max-ops-per-buffer threshold to reduce commit frequency on decode workloads.
|
I have a lot of changes to merge in. I am testing my port of mlx-swift-lm (https://github.com/lemonade-sdk/lemon-mlx-engine) against the MLX ROCm core (https://github.com/lemonade-sdk/lemon-mlx-core-amd). I got Qwen3 working and am working on Qwen3Next right now; it's having weird issues. There are tons of problems with the ROCm backend that I have traced to "different rounding" causing unstable outputs, but a lot of that is fixed now, at least regarding Qwen models. Once I get your changes merged into my repo, I will push a PR into yours with my changes. I have made optimizations as well; there are problems with the fallback system when rocBLAS functions aren't compatible with, or don't exist for, the architecture. |
|
Once I get Qwen3Next working at a reasonable speed I will do the PR. |
I have added you as a collaborator on my fork, so you should be able to push changes directly to this branch (and therefore directly to this PR). Again, amazing work 🚀 |
Experiment with ROCm backend.
Install MLX with the ROCm backend using:

```shell
CMAKE_ARGS="-DMLX_BUILD_ROCM=ON" pip install -e .
```
closes #2556
Inspired by @zcbenz