
[Experiment] ROCm backend#2300

Open
NripeshN wants to merge 165 commits into ml-explore:main from NripeshN:rocm-support

Conversation

@NripeshN
Contributor

@NripeshN NripeshN commented Jun 16, 2025

Experiment with ROCm backend.

install MLX with ROCm backend using:

mkdir build && cd build
cmake -DMLX_BUILD_ROCM=ON \
      -DCMAKE_PREFIX_PATH=/opt/rocm \
      -DCMAKE_HIP_ARCHITECTURES="gfx90a;gfx1100" \
      ..
make -j$(nproc)

closes #2556

Inspired by @zcbenz

@NripeshN NripeshN changed the title [Experiment] ROCm backend initial push [Experiment] ROCm backend Jun 16, 2025
@lin72h

lin72h commented Jun 17, 2025

What an unexpected and amazing surprise! I'm absolutely thrilled.

@NripeshN
Contributor Author

@awni
What do you think of this PR? Does this have the potential to be merged into main? I can turn this PR from experimental to WIP if so.

@angeloskath
Member

I think this is good to stay as an experiment branch for some time while we work on core and CUDA. I don't think we have the bandwidth to merge this for a few months at least. Sorry if this is disappointing, @NripeshN; I don't mean to discourage you from working on it.

@akshat2602

I would love to see the ROCm backend get more traction. AMD's new AI-series processors have a unified-memory advantage similar to Apple Silicon's, and getting MLX to run on them would be neat.

@countradooku

Stole my idea :(

@goniz

goniz commented Jan 22, 2026

How is this even possible for such an awesome PR to be left like this?

Copilot AI review requested due to automatic review settings January 24, 2026 17:08

Copilot AI left a comment


Pull request overview

This PR adds experimental ROCm backend support to MLX, enabling execution on AMD GPUs. The implementation mirrors the CUDA backend structure, providing HIP-based implementations of core operations, memory management, and device handling.

Changes:

  • Added ROCm backend infrastructure with device management, memory allocation, and stream handling
  • Implemented HIP kernels for unary, binary, ternary operations, reductions, normalization (softmax, layer_norm, rms_norm), RoPE, and sorting
  • Updated build system (CMake) to support ROCm compilation with configurable GPU architectures

Reviewed changes

Copilot reviewed 59 out of 59 changed files in this pull request and generated 13 comments.

| File | Description |
| --- | --- |
| CMakeLists.txt | Added MLX_BUILD_ROCM option and ROCm library detection |
| mlx/CMakeLists.txt | Integrated ROCm backend build configuration |
| mlx/device.cpp | Added ROCm device availability checks |
| mlx/backend/rocm/*.hip | HIP kernel implementations for various operations |
| mlx/backend/rocm/device.* | ROCm device and stream management |
| mlx/backend/rocm/allocator.* | ROCm-specific memory allocator using HIP unified memory |
| mlx/backend/rocm/worker.* | Async task execution worker for stream synchronization |
| mlx/backend/rocm/utils.* | HIP utility functions and error handling |
| mlx/backend/rocm/jit_module.* | JIT compilation support using HIPRTC |
| mlx/backend/rocm/device/*.hpp | Device-side utility functions and type definitions |
| mlx/backend/rocm/CMakeLists.txt | ROCm backend build configuration |


…ather, scatter, logsumexp, random bits generation, and sorting. Introduce new kernels for efficient computation and integrate with existing ROCm utilities. Update CMake configuration to include new source files and dependencies. Enhance error handling and ensure compatibility with different data types. This commit significantly expands the functionality of the ROCm backend.
@goniz

goniz commented Jan 24, 2026

👑👑👑

@NripeshN
Contributor Author

Can anyone run

CMAKE_ARGS="-DMLX_BUILD_ROCM=ON" pip install -e .
CMAKE_ARGS="-DMLX_BUILD_ROCM=ON -DMLX_ROCM_ARCHITECTURES={based on your GPU}" pip install -e .

Replace {based on your GPU} with your GPU architecture

You can run

rocm-smi

to get your GPU information
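If `rocminfo` is available, it reports the gfx target directly; a minimal sketch, demonstrated here on a captured sample line so it runs anywhere (on a real ROCm machine you would pipe `rocminfo` itself):

```shell
# On a ROCm machine:  rocminfo | grep -m1 -o 'gfx[0-9a-f]\+'
# prints the architecture string to pass as -DMLX_ROCM_ARCHITECTURES.
# Demonstrated on a sample rocminfo output line:
sample='  Name:                    gfx1151'
echo "$sample" | grep -o 'gfx[0-9a-f]\+'
```

This prints `gfx1151` for the sample line above.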

@goniz

goniz commented Jan 24, 2026

I'm getting this CMake error:

CMAKE_ARGS="-DMLX_BUILD_ROCM=ON -DMLX_ROCM_ARCHITECTURES=gfx1151" pip install -e .

      -- Configuring done (4.8s)
      CMake Error: The following variables are used in this project, but they are set to NOTFOUND.
      Please set them or make sure they are set and tested correctly in the CMake files:
      /home/goniz/Work/mlx/LAPACK_INCLUDE_DIRS
         used as include directory in directory /home/goniz/Work/mlx
      
      CMake Error in CMakeLists.txt:
        HIP_ARCHITECTURES is empty for target "mlx".
      
      
      CMake Error in CMakeLists.txt:
        HIP_ARCHITECTURES is empty for target "mlx".
      
      
      -- Generating done (0.0s)
      CMake Generate step failed.  Build files cannot be regenerated correctly.

Running on Strix Halo (gfx1151)

@NripeshN
Contributor Author

I'm getting this CMake error:

CMAKE_ARGS="-DMLX_BUILD_ROCM=ON -DMLX_ROCM_ARCHITECTURES=gfx1151" pip install -e .
     -- Configuring done (4.8s)
     CMake Error: The following variables are used in this project, but they are set to NOTFOUND.
     Please set them or make sure they are set and tested correctly in the CMake files:
     /home/goniz/Work/mlx/LAPACK_INCLUDE_DIRS
        used as include directory in directory /home/goniz/Work/mlx
     
     CMake Error in CMakeLists.txt:
       HIP_ARCHITECTURES is empty for target "mlx".
     
     
     CMake Error in CMakeLists.txt:
       HIP_ARCHITECTURES is empty for target "mlx".
     
     
     -- Generating done (0.0s)
     CMake Generate step failed.  Build files cannot be regenerated correctly.

Running on Strix Halo (gfx1151)

Could you retry with the latest push, please? (P.S. keep your fingers crossed while it compiles; worked for me on the 138th try 😅)

… string formatting, replacing fmt library usage. Remove unused event.cpp file. Update kernel name generation and parameter formatting for consistency.
@goniz

goniz commented Jan 25, 2026

  Created wheel for mlx: filename=mlx-0.30.4.dev20260125+cadf18c1-0.editable-cp314-cp314-linux_x86_64.whl size=4722 sha256=72c664adbfc4fb9ec317522a8d83b84f85d599d08bd691d7fec3abfdb6f3a5e9
  Stored in directory: /tmp/pip-ephem-wheel-cache-nt7w6bq0/wheels/8a/63/d1/d7d629a5ff73457822bb71aa527c083674bb19ca314735cd05
Successfully built mlx
Installing collected packages: mlx
Successfully installed mlx-0.30.4.dev20260125+cadf18c1

Now what can I test? 😍

@goniz

goniz commented Jan 25, 2026

I'm getting this:

ImportError: /home/goniz/Work/mlx/python/mlx/lib/libmlx.so: undefined symbol: _ZN3mlx4core11Convolution8eval_gpuERKSt6vectorINS0_5arrayESaIS3_EERS3_

@NripeshN
Contributor Author

I'm getting this:

ImportError: /home/goniz/Work/mlx/python/mlx/lib/libmlx.so: undefined symbol: _ZN3mlx4core11Convolution8eval_gpuERKSt6vectorINS0_5arrayESaIS3_EERS3_

I forgot to test the Python build, my bad. Can you try it now?

Unfortunately I might not be able to help after it compiles since I don't have an AMD GPU to run tests 😔 I've tried replicating most things from CUDA, so hopefully it works.

goniz and others added 27 commits March 1, 2026 15:46
Use a hash of the module name for hiprtcCreateProgram to avoid
filesystem filename limits when HIP runtime compiler creates
temporary files. Also add get_hsaco_path() helper to split long
module names into nested directories for disk caching.

This fixes JIT compilation failures with complex fused kernels
that generate very long module names (>255 chars).
HIP doesn't provide native math functions for hip_bfloat16 and __half,
so add device function overloads that convert to float, compute, and
convert back. This enables JIT-compiled kernels to use math operations
on reduced-precision tensors.

Functions added: abs, exp, log, sqrt, rsqrt, sin, cos, tan, sinh, cosh,
tanh, asin, acos, atan, asinh, acosh, atanh, ceil, floor, rint, log2,
log10, log1pf, expm1f, erff, erfinvf, powf, fmodf, truncf, atan2f.
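The convert-compute-convert pattern is simple; here is a standalone C++ sketch using a toy truncation-based bfloat16 stand-in (the real code overloads these functions for hip_bfloat16 and __half as device functions):

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Toy bfloat16: the top 16 bits of an IEEE-754 float, standing in for
// hip_bfloat16 so this sketch compiles without a HIP toolchain.
struct bf16 {
  uint16_t bits;
  static bf16 from_float(float f) {
    uint32_t u;
    std::memcpy(&u, &f, 4);
    return bf16{static_cast<uint16_t>(u >> 16)};
  }
  float to_float() const {
    uint32_t u = static_cast<uint32_t>(bits) << 16;
    float f;
    std::memcpy(&f, &u, 4);
    return f;
  }
};

// Convert to float, compute in full precision, convert back: the same
// shape as the device-side overloads the commit adds for exp, log, sqrt, ...
inline bf16 bf16_exp(bf16 x) { return bf16::from_float(std::exp(x.to_float())); }
```

The result carries only bfloat16 precision, but the intermediate computation uses the full-precision float math library, which is exactly what the added overloads do on device.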
Add has_only_singleton_batch_dims() helper to correctly detect when broadcasted singleton dimensions can be treated as non-batched matrices, fixing page faults and incorrect results in certain quantized matmul cases.
- Add qmv_warp_shared_batched_kernel to optimize batched QMV with singleton dimensions.
- Add gather_qmv_warp_shared_kernel to accelerate MoE gather operations during decode.
- Update dispatch logic in QuantizedMatmul::eval_gpu and GatherQMM::eval_gpu to use these fast paths.
Improves decoding speed for 4-bit and 6-bit quantized models by 10-15%.
By reading up to 8 quantized values at once using uint32_t vector loads,
we better saturate the memory bandwidth instead of doing multiple byte-sized
loads. Also unskips passing tests in rocm_skip.py.
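The vectorized-load idea, sketched in scalar Python (the kernel does this with uint32_t loads; the function name here is illustrative):

```python
def unpack_u4x8(word: int) -> list[int]:
    """Unpack eight 4-bit quantized values from one 32-bit word,
    lowest nibble first: one 32-bit load replaces eight byte loads."""
    return [(word >> (4 * i)) & 0xF for i in range(8)]
```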
Tuning the number of threads per column to 16 rather than full WARP_SIZE
significantly improves decoding generation performance (from 14.5 to 18.2 TPS
on GLM-4 6bit) due to better hardware occupancy and register usage.
- Use sincosf() instead of separate cosf() + sinf() calls for better performance
- Add optimized 1D kernels (rope_single_1d, rope_single_freqs_1d) for single-token decode
- Use 256-thread 1D blocks instead of 16x16 2D blocks for small workloads
- Inline implementation in 1D kernels to reduce function call overhead

The decode case (B=1, T=1) now uses flat indexing which provides better
occupancy for the small number of elements typical in LLM decode steps.
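As a reference for the decode-path math, a scalar RoPE sketch for a single token. This assumes the pairwise-rotation layout and computes sin and cos together per angle, the way the optimized kernel uses sincosf(); MLX's actual traditional/default layouts may differ:

```python
import math

def rope_1d(x: list[float], pos: int, base: float = 10000.0) -> list[float]:
    """Rotate consecutive pairs of x by position-dependent angles."""
    d = len(x)
    out = [0.0] * d
    for i in range(d // 2):
        theta = pos / (base ** (2 * i / d))
        s, c = math.sin(theta), math.cos(theta)  # one sincos per angle
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out
```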
Tune quantized matmul path selection for decode/prefill shapes, add bounded dequant cache with safe source retention, and wire QMV block sizing heuristics. Extend ROCm SDPA/flash dispatch to head dim 256 and add a pointwise conv fast path to reduce launch overhead in decode-like workloads.
Key dequant-cache entries by GPU buffer pointers to avoid stale hits from array-id reuse, and align QMV thread/column defaults with architecture-aware warp sizing across both QMM and GatherQMM paths.
Prefer flash SDPA for decode-like BF16/F16 configurations with long KV cache and no masks, while preserving vector fallback behavior. Also skip the AddMM input copy when beta is zero to eliminate redundant device-to-device copy work.
Allow strided-batched GEMM when collapsed batch dimensions are uniformly strided (including flattened multi-dimensional batches) instead of restricting to single-dimension batches only. This reduces fallback per-batch launch overhead and keeps more matmuls on the rocBLAS batched path.
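The condition above can be sketched as follows (hypothetical helper; the real check lives in the ROCm matmul dispatch): walking the flattened batch index must advance memory by one constant step, which holds exactly when the batch strides form a contiguous layout scaled by the innermost batch stride.

```python
def uniform_batch_stride(shape, strides):
    """Return the constant memory step between consecutive batch matrices
    if the collapsed batch dims are uniformly strided (so a single
    strided-batched GEMM applies), else None. No batch dims -> step 0."""
    batch_shape, batch_strides = shape[:-2], strides[:-2]
    if not batch_shape:
        return 0
    step = batch_strides[-1]
    expected = step
    for dim, st in zip(reversed(batch_shape), reversed(batch_strides)):
        if st != expected:
            return None
        expected *= dim
    return step
```

For example, a contiguous (2, 3, 4, 5) array collapses to a 6-matrix batch with a uniform step of 20 elements, while a sliced array whose outer batch stride breaks the pattern falls back to per-batch launches.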
Add env-configurable rocBLAS solution-index selection for float32 and bfloat16 GEMM/strided-batched GEMM paths across matmul, quantized QMM dequant GEMM, and shared rocBLAS wrappers. Keep default behavior unchanged (index 0), and automatically fall back to standard algorithms if a configured solution index fails.
Select QMV threads-per-column based on problem size instead of forcing warp-size on RDNA, and tune cols-per-block accordingly for 8-bit paths. This restores better out-of-box decode throughput on smaller models while preserving faster large-model defaults.
Use a larger shared-memory chunk (2048 vs 1024) in QMV warp-shared kernels to reduce chunk loop overhead and synchronization frequency. This improves out-of-box decode throughput on Qwen3.5 models without requiring runtime tuning knobs.
Deduplicate temporary buffer keepalive entries per command buffer to lower host-side bookkeeping and callback payload size, and raise the default max-ops-per-buffer threshold to reduce commit frequency on decode workloads.
@Geramy

Geramy commented Mar 27, 2026

I have a lot of changes to merge in. I am testing my port of mlx-swift-lm (https://github.com/lemonade-sdk/lemon-mlx-engine) against the MLX ROCm core (https://github.com/lemonade-sdk/lemon-mlx-core-amd). I got Qwen3 working and am working on Qwen3Next right now; it's having weird issues. There are tons of problems with the ROCm backend that I have traced to "different rounding" causing unstable outputs, but a lot of that is fixed now, at least for Qwen models. Once I get your changes merged into my repo I will push a PR into yours with my changes. I have made optimizations as well; there are also problems with the fallback system when functions in rocBLAS aren't compatible with, or don't exist for, the architecture.

@Geramy

Geramy commented Mar 27, 2026

Once I get Qwen3Next working at a reasonable speed I will do the PR.

@NripeshN
Contributor Author

I have a lot of changes to merge

I have added you as a collaborator on my fork, so you should be able to push changes directly to this branch (and therefore to this PR). Again, amazing work 🚀


Development

Successfully merging this pull request may close these issues.

Add ROCm Support for AMD GPUs

8 participants