Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
206 commits
Select commit Hold shift + click to select a range
8bb8b76
[Experiment] ROCM backend initial push
NripeshN Jun 16, 2025
ac5adfa
increment 1: few ops and jit update
NripeshN Jun 18, 2025
cc4de6a
Increment 2: Implement major ops and add structure similar to cuda
NripeshN Jun 18, 2025
1163da1
Merge remote-tracking branch 'upstream/main' into rocm-support
NripeshN Jan 24, 2026
667cd9b
rocm yaay
NripeshN Jan 24, 2026
8780ad9
Implement ROCm support for various operations including arg reduce, g…
NripeshN Jan 24, 2026
63d6b6a
chore fix cmake
NripeshN Jan 24, 2026
7c1b29d
Merge branch 'main' into rocm-support
NripeshN Jan 24, 2026
ee8b705
compile fix
NripeshN Jan 25, 2026
9aa0f5c
Refactor error handling in ROCm backend to use std::ostringstream for…
NripeshN Jan 25, 2026
cadf18c
lint
NripeshN Jan 25, 2026
6fa7c7c
add more features
NripeshN Jan 26, 2026
57941f9
Enhance ROCm backend with new features including binary operations, L…
NripeshN Jan 26, 2026
1856341
Remove optional MIOpen support from ROCm backend CMake configuration.…
NripeshN Jan 26, 2026
2e27dc9
Add scaled dot product attention kernel and update ROCm convolution i…
NripeshN Jan 26, 2026
da275f7
Fix symbol linking issue
NripeshN Jan 26, 2026
499d2a6
lazy load GPU
NripeshN Jan 26, 2026
c30b211
Add general gather and scatter kernels for arbitrary indexing in ROCm…
NripeshN Jan 26, 2026
86e4f85
Add dynamic copy kernel and gather operation in ROCm backend
NripeshN Jan 26, 2026
7141d8c
Add quantized matrix multiplication and gather QMM kernel in ROCm bac…
NripeshN Jan 26, 2026
1c74fba
Merge remote-tracking branch 'upstream/main' into rocm-support
NripeshN Jan 27, 2026
04efa16
Fix HIP include paths for C++ standard library headers
NripeshN Feb 3, 2026
bf993f8
Rewrite ROCm sort with custom merge sort implementation
NripeshN Feb 3, 2026
b76745e
Fix ROCm sort compilation errors
NripeshN Feb 3, 2026
969fd0b
Remove duplicate is_available() and unavailable header from ROCm eval…
Geramy Feb 3, 2026
b82594d
Add device_info.cpp for ROCm backend
Geramy Feb 3, 2026
231c078
Include memory.h in ROCm allocator for proper symbol visibility
Geramy Feb 3, 2026
8de6a7a
Fix all ROCm backend compiler warnings
NripeshN Feb 3, 2026
04b2e8d
Fix remaining ROCm backend compiler warnings
NripeshN Feb 3, 2026
bf3b69b
Add ROCm Python bindings and test skip list
NripeshN Feb 3, 2026
9af0755
Add MLX_API to rocm::is_available() for proper symbol export
NripeshN Feb 3, 2026
90377cc
Fix ROCm allocator to fall back to hipMalloc when managed memory fails
NripeshN Feb 3, 2026
b330ad1
Fix ROCm allocator to use hipHostMalloc when managed memory unavailable
NripeshN Feb 3, 2026
39b2926
Fix WARP_SIZE to be architecture-dependent for ROCm
NripeshN Feb 3, 2026
467fb00
Fix macro conflicts in WARP_SIZE and MAX_NDIM definitions
NripeshN Feb 3, 2026
4545bac
Fix WARP_SIZE_ROW namespace reference
NripeshN Feb 3, 2026
6e6d837
Fix MAX_NDIM macro reference in compiled.cpp
NripeshN Feb 3, 2026
54c8833
Fix cross-type copy for ROCm backend
NripeshN Feb 4, 2026
1adfed0
Fix ROCm copy and arg_reduce for correct warp size
NripeshN Feb 4, 2026
7d554b0
Fix CMAKE_HIP_ARCHITECTURES to respect user-provided value
NripeshN Feb 4, 2026
df4d228
Fix MAX_NDIM conflict and restore dispatch_all_types for copy
NripeshN Feb 4, 2026
4746543
Add proper CastOp for ROCm copy to handle all type conversions
NripeshN Feb 4, 2026
aa4ff37
Add missing half/bfloat16 conversions in CastOp
NripeshN Feb 4, 2026
6e4d799
Remove duplicate is_complex definition, use from utils.hpp
NripeshN Feb 4, 2026
97afbd5
Improve ROCm backend to match CUDA functionality
NripeshN Feb 4, 2026
ad9c9cc
Fix reduce operations to match CUDA type constraints
NripeshN Feb 4, 2026
5269e6a
Fix Max/Min reduce ops to use explicit specializations instead of con…
NripeshN Feb 4, 2026
6e4e202
Exclude complex types from reduce operations (not yet supported on ROCm)
NripeshN Feb 4, 2026
4aec5ec
Fix type_identity usage - use mlx::core::type_identity instead of std…
NripeshN Feb 4, 2026
a17961e
Include reduce_utils.hpp for allocate_same_layout
NripeshN Feb 4, 2026
216e533
Enhance ROCm support in CMake and backend
NripeshN Feb 4, 2026
4bf5f22
Add hipFloatComplex support for scan and reduce operations
NripeshN Feb 4, 2026
abc2634
Add debug output to copy_contiguous
NripeshN Feb 4, 2026
833bfc7
Fix const cast in debug output
NripeshN Feb 4, 2026
f10845a
Add more debug output to copy_contiguous
NripeshN Feb 4, 2026
f2f976b
Add stream sync before kernel launch
NripeshN Feb 4, 2026
9426d6c
Use hipMemcpy for small copies
NripeshN Feb 4, 2026
94868fa
Revert to simple kernel launch
NripeshN Feb 4, 2026
3990c3d
Remove debug output from copy_contiguous
NripeshN Feb 4, 2026
a74e904
Fix WARP_SIZE mismatch between host and device code
NripeshN Feb 4, 2026
9a05cd0
Refactor all_reduce to support all types using dispatch_all_types
NripeshN Feb 4, 2026
474f921
Fix all_reduce type casting for And/Or operations
NripeshN Feb 4, 2026
700de96
Add is_valid_reduce_op check to skip invalid type/op combinations
NripeshN Feb 4, 2026
5a9b067
Add complex type support for Min/Max reduce operations
NripeshN Feb 4, 2026
e2c5fcd
Add complex type support to reduce.hpp operators
NripeshN Feb 4, 2026
1766e04
Use SFINAE instead of if constexpr for complex type handling in reduc…
NripeshN Feb 4, 2026
af0acd6
Add complex type support for unary operations
NripeshN Feb 4, 2026
d655bbe
Include hip_complex.h in fp16_math.hpp for hipFloatComplex type
NripeshN Feb 4, 2026
d33bd4c
Refactor unary ops to use dispatch_all_types with type checking
NripeshN Feb 4, 2026
59e8097
Handle -inf case in complex exp function
NripeshN Feb 4, 2026
363b7eb
Add float16 and bfloat16 support to arange
NripeshN Feb 4, 2026
edb9cd7
Fix GPU architecture string in JIT module - gcnArchName already conta…
NripeshN Feb 4, 2026
f2a7f4f
Replace hip/std/array with simple array implementation for JIT
NripeshN Feb 4, 2026
31093f5
Add standard type definitions for JIT compilation
NripeshN Feb 4, 2026
6cf9a3f
Add missing unary and binary ops to JIT includes
NripeshN Feb 4, 2026
3082c41
Add uint16_t, int16_t, uint8_t, int8_t type definitions for JIT
NripeshN Feb 4, 2026
5f1a4d4
Add complex64 support to binary_op_gpu_inplace
NripeshN Feb 4, 2026
0c7e7ea
Add if constexpr check for supports_binary_op in launch_kernel
NripeshN Feb 4, 2026
5d0deba
Fix supports_binary_op for comparison operators with complex types
NripeshN Feb 4, 2026
f61797f
Remove complex64 from binary_op_gpu_inplace (not all ops support it)
NripeshN Feb 4, 2026
6870081
Fix supports_binary_op to use else if constexpr chain
NripeshN Feb 4, 2026
eed4267
Remove if constexpr check from launch_kernel (was causing issues)
NripeshN Feb 4, 2026
cef0bbc
Enhance ROCm backend with general binary operation support and improv…
NripeshN Feb 5, 2026
49c1dce
Enhance ROCm backend with dynamic memory management and kernel optimi…
NripeshN Feb 5, 2026
1fa3a44
Enhance ROCm backend with dynamic memory initialization and kernel ar…
NripeshN Feb 5, 2026
8a21489
Enhance ROCm backend with new all-reduce functionality and kernel opt…
NripeshN Feb 5, 2026
780a83d
Remove input dilation check from gemm_conv function in ROCm backend t…
NripeshN Feb 5, 2026
c40fd68
Refactor ROCm backend gather and scatter operations for improved perf…
NripeshN Feb 6, 2026
5993979
lint
NripeshN Feb 6, 2026
436b65d
Add hip_kernel support for ROCm backend and enhance Python bindings
NripeshN Feb 6, 2026
d6019c0
Enhance row_reduce function in ROCm backend to support contiguous data
NripeshN Feb 6, 2026
3be5a10
Remove unused type traits from ROCm unary kernel implementation to st…
NripeshN Feb 6, 2026
7672448
Implement single position RoPE kernel in ROCm backend
NripeshN Feb 7, 2026
b4a2a36
Refactor warp reduction logic in ROCm layer and RMS normalization ker…
NripeshN Feb 7, 2026
c550158
Add support for bfloat16 data type in scaled dot product attention ke…
NripeshN Feb 7, 2026
16c1ef4
Disable ROCm SDPA kernel due to warp size incompatibility
NripeshN Feb 7, 2026
f5aac8d
Rewrite ROCm SDPA kernel to be warp-size agnostic
NripeshN Feb 7, 2026
a6bf8cb
Temporarily disable ROCm SDPA kernel to debug memory fault
NripeshN Feb 7, 2026
af26ee9
Re-enable warp-agnostic ROCm SDPA kernel
NripeshN Feb 7, 2026
c6d9a92
ci trigger
NripeshN Feb 8, 2026
9d73b71
Added github workflow for rocm strix halo
goniz Jan 27, 2026
2285120
Fix ROCm bfloat16 matmul and kernel type handling
goniz Feb 25, 2026
0a08672
Fix ROCm non-uniform batched matmul for fp16/bfloat16
goniz Feb 25, 2026
3a9c39b
Fix ROCm affine quantized matmul sign handling
goniz Feb 25, 2026
8684c46
Fix ROCm non-power-of-two quantized packing
goniz Feb 25, 2026
fb3a67e
Replace Qwen3 smoke script with pytest suite
goniz Feb 25, 2026
a01a7bd
Merge original MLX main into rocm-support-fixes
goniz Feb 25, 2026
8dec0d4
Fix ROCm LogAddExp bf16 handling and expand generation matrix
goniz Feb 25, 2026
9c8718d
Fix ROCm GatherQMM index contiguity
goniz Feb 25, 2026
ac27e78
Support strided GatherQMM indices on ROCm
goniz Feb 25, 2026
11b2920
Fix ROCm hot-path pointer access to avoid host synchronization
goniz Feb 25, 2026
4758c15
Accelerate ROCm depthwise Conv1d grouped path
goniz Feb 25, 2026
1e7e977
Fix ROCm GatherMM hard sync in fallback path
goniz Feb 25, 2026
cbcd332
Fix ROCm BLAS pytest failures in direct test runs
goniz Feb 25, 2026
f3a30e0
Implement ROCm MaskedScatter kernel for boolean indexing
goniz Feb 25, 2026
926fdee
Fix ROCm SDPA crashes in GQA causal paths
goniz Feb 25, 2026
1d95664
Fix ROCm fp quantized matmul decode paths
goniz Feb 26, 2026
b5c0ba3
Fix ROCm quantized fallback paths for fp and qqmm
goniz Feb 26, 2026
77320af
Accelerate ROCm quantized decode path for generation
goniz Feb 26, 2026
9d23561
Optimize ROCm quantized matmul decode kernels
goniz Feb 26, 2026
04805fd
Optimize ROCm GatherQMM warp decode path
goniz Feb 26, 2026
fed4ca0
Tune ROCm quantized warp kernels for decode throughput
goniz Feb 26, 2026
0618c69
Tune ROCm 8-bit quantized decode kernels
goniz Feb 26, 2026
ff3fcfc
Tune ROCm quantized subgroup threading for decode
goniz Feb 26, 2026
43cd9dc
Optimize ROCm GEMV batched launch parameter handling
goniz Feb 26, 2026
2f5964f
Fix ROCm gather GEMV indexing for batched layouts
goniz Feb 26, 2026
698f86c
Optimize ROCm APU allocator and fix high CPU spin-wait
goniz Feb 27, 2026
17b7cb8
Add bfloat16 support for rocBLAS GEMM operations
goniz Feb 27, 2026
f29e4e4
Optimize ROCm GEMV with vectorized loads and wider n_per_thread
goniz Feb 27, 2026
a6967d2
Increase ROCm max ops per buffer from 20 to 1000
goniz Feb 27, 2026
8c56f29
Fix quantized matmul array creation bug and simplify kernels
goniz Feb 27, 2026
197e844
Merge branch 'main' of github.com:ml-explore/mlx into rocm-support-fixes
goniz Feb 27, 2026
a1a642e
Optimize ROCm backend: Fix SDPA fallback, enable QMM rocBLAS dequant,…
goniz Feb 28, 2026
719dc9d
Add optimized Flash Attention and reduce rocBLAS dispatch overhead
goniz Feb 28, 2026
0c5144a
ROCm: Add MLA Flash Attention support and fix rocBLAS dispatch
goniz Mar 1, 2026
7d5eb69
benchmark: update default max-tokens to 1000
goniz Mar 1, 2026
e8e3a45
benchmark: remove --no-warmup from llama-completion
goniz Mar 1, 2026
958240a
benchmark: redact prompt from logs to reduce terminal clutter
goniz Mar 1, 2026
d55d2a2
ROCm: Fix JIT compilation 'File name too long' error
goniz Mar 1, 2026
805d272
ROCm: Add math function overloads for bfloat16 and half types
goniz Mar 1, 2026
b44396a
ROCm: Fix quantized GEMM fallback correctness
goniz Mar 1, 2026
f1687cc
ROCm: fix 5/6-bit affine quantized matmul page faults
goniz Mar 1, 2026
108195a
ROCm: Fix quantized matmul with singleton batch dimensions
goniz Mar 1, 2026
ec84dfd
ROCm: Optimize quantized matmul and MoE gather for decode shapes
goniz Mar 2, 2026
f4634b4
ROCm: Vectorize 4-bit and 6-bit memory access in qmv_warp_shared_kernel
goniz Mar 2, 2026
a69c471
ROCm: Set default THREADS_PER_COL to 16 for qmv warp kernels
goniz Mar 2, 2026
24ecc76
ROCm: Optimize RoPE kernel for decode with sincosf and 1D layout
goniz Mar 2, 2026
4353b1b
ROCm: vectorize 6-bit fallback QMV kernels
goniz Mar 2, 2026
b811a89
ROCm: optimize QMM dispatch and extend SDPA head-dim support
goniz Mar 3, 2026
b38695f
ROCm: harden QMM cache keys and tune QMV launch defaults
goniz Mar 3, 2026
bc3bd38
ROCm: improve SDPA decode dispatch and avoid AddMM copy
goniz Mar 3, 2026
2884e85
ROCm: broaden batched GEMM fast-path stride detection
goniz Mar 3, 2026
7c80030
ROCm: add configurable rocBLAS GEMM solution-index dispatch
goniz Mar 3, 2026
184ef21
ROCm: make QMV launch defaults shape-adaptive
goniz Mar 3, 2026
c6883ca
ROCm: increase shared QMV tile size for decode
goniz Mar 3, 2026
d5d8b31
ROCm: reduce command-encoder scheduling overhead
goniz Mar 3, 2026
7bca990
ROCm: add sorted-rhs gather scheduling fast path
goniz Mar 3, 2026
20bcdd2
ROCm: extend sorted-rhs gather schedule across QMV dispatch
goniz Mar 3, 2026
d07f6a5
Benchmarks: route Qwen3.5 vision models through mlx-vlm
goniz Mar 3, 2026
1c93a6f
ROCm: add architecture-aware QMV crossover and tiny-K dispatch
goniz Mar 3, 2026
6be6435
ROCm: add alignment-aware QMV variant selection
goniz Mar 3, 2026
3ca29dc
ROCm: fix no-shared QMV accumulator shadowing
goniz Mar 3, 2026
879a200
Merge branch 'main' of github.com:ml-explore/mlx into rocm-support-fixes
goniz Mar 3, 2026
9193df5
Merge NripeshN/mlx rocm-support into upstream main
Geramy Mar 25, 2026
9fddf1c
Add RDNA 3.5/4 architectures and parallel HIP compilation
Geramy Mar 25, 2026
3ae44dc
Fix parallel-jobs flag: single dash for hipcc/clang
Geramy Mar 25, 2026
2b8a7d1
Limit HIP parallel-jobs to half of available CPUs
Geramy Mar 25, 2026
c2eb919
Add missing gpu::init() and SliceUpdate::eval_gpu stub for ROCm
Geramy Mar 25, 2026
26e733c
Implement ROCm-optimized SliceUpdate::eval_gpu
Geramy Mar 25, 2026
edd89a1
Fix bfloat16/half JIT compilation for ROCm fused kernels
Geramy Mar 25, 2026
1ab4186
Simplify JIT preamble ops: always promote through float
Geramy Mar 25, 2026
d03fa7c
Fix critical bug: JIT KernelArgs passed CPU pointers instead of GPU
Geramy Mar 25, 2026
76741bc
Remove gfx1150/1151/1152/1200/1201 from rocBLAS supported list
Geramy Mar 25, 2026
9336df8
Add rocBLAS fallback to naive_gemm when Tensile kernel missing
Geramy Mar 25, 2026
f92d2d2
Add missing kernel_utils.hpp include for gpu_ptr in rocblas_gemm
Geramy Mar 25, 2026
8acadb4
Probe rocBLAS bf16 GEMM at device init, fallback to naive_gemm
Geramy Mar 25, 2026
bfab6fb
Always use naive_gemm for bfloat16 GEMM on ROCm
Geramy Mar 25, 2026
c8c9c8e
ROCm bug fixes + optimized quantized GEMV kernel
Geramy Mar 26, 2026
2f47aeb
Promote JIT binary ops through float, restore rocBLAS for gfx1151
Geramy Mar 26, 2026
6520667
GatherQMM: ensure contiguous indices, SDPA: add head_dim=256
Geramy Mar 26, 2026
00d8c2e
SDPA GPU decomposition, naive_gemm for all types, GatherQMM contiguou…
Geramy Mar 26, 2026
4a5bb0f
Metal-compatible QMM accumulation, JIT stderr suppression
Geramy Mar 26, 2026
73470d8
Fix GatherQMM memory corruption, add index bounds clamping
Geramy Mar 26, 2026
1e50c74
Kernel audit: match Metal precision across RMSNorm, sort, softmax, ops
Geramy Mar 26, 2026
1793485
Fix batched matmul: missing bfloat16/float16 in loop-based GQA path
Geramy Mar 27, 2026
840d028
Add head_dim=256 dispatch to SDPA vector kernel
Geramy Mar 27, 2026
b48adae
Merge upstream main into rocm-support
NripeshN Mar 27, 2026
fe75135
Merge goniz/rocm-support-fixes with extensive ROCm optimizations
NripeshN Mar 27, 2026
d30fe29
Merge upstream NripeshN/mlx rocm-support with ROCm optimizations
Geramy Mar 27, 2026
5ffb863
Enable 4-bit fast gather QMV dispatch for MoE decode
Geramy Mar 27, 2026
b1300b9
Optimize ROCm allocator for integrated GPUs (APU)
Geramy Mar 27, 2026
780b4fe
Prefer shared-memory QMV over noshared variant for decode
Geramy Mar 27, 2026
0ec6b45
Add expert-grouped prefill kernel for GatherQMM (3.4x prompt speedup)
Geramy Mar 27, 2026
c9167d2
Allocator: prefer hipExtMallocWithFlags for APU, fallback to hipMallo…
Geramy Mar 27, 2026
a66e273
Add WMMA-accelerated prefill kernel for GatherQMM on RDNA 3/3.5/4
Geramy Mar 27, 2026
e35d6aa
WMMA prefill kernel: support non-aligned M, sort unsorted indices
Geramy Mar 27, 2026
435afdc
Add GPU-only expert-batched gather QMV kernel for low-expert MoE
Geramy Mar 27, 2026
bc4d62f
Add hipBLASLt GEMM integration for bf16/fp16 matmul on ROCm
Geramy Mar 27, 2026
b8b56b1
hipBLASLt: add to QMM dequant+GEMM path for bf16 (2.6x prompt speedup)
Geramy Mar 27, 2026
7ac6efd
hipBLASLt in QMM dequant path + CommandEncoder graph capture API
Geramy Mar 27, 2026
b913c68
Strided copy kernels for ensure_row_contiguous in QMM
Geramy Mar 27, 2026
da1925b
Allocator: power-of-2 rounding for large allocs (>= 1MB)
Geramy Mar 28, 2026
65958fa
Allocator: use system RAM limit for iGPU, power-of-2 rounding for lar…
Geramy Mar 28, 2026
b010eee
Allocator: revert power-of-2 rounding, keep hipExtMallocWithFlags
Geramy Mar 28, 2026
f26c802
Fix CU count comment: 40 CUs (20 WGPs) on gfx1151
Geramy Mar 28, 2026
251c8d8
Merge pull request #5 from lemonade-sdk/rocm-optimizations
Geramy Mar 30, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions .github/workflows/build_rocm.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# CI workflow: builds an MLX wheel with the ROCm backend enabled, smoke-tests
# it on the GPU, runs two short mlx-lm generations under lldb, and uploads the
# wheel plus the captured logs as artifacts.
name: Build ROCm and Test

# Triggered by pushes to the rocm-support branch or by a manual dispatch.
on:
  push:
    branches:
      - rocm-support
  workflow_dispatch:

jobs:
  build-and-test:
    # Runner label; presumably a self-hosted machine (not a GitHub-hosted
    # label) -- confirm against the repository's runner configuration.
    runs-on: strix-halo

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      # Creates the virtualenv `venv` with uv and installs mlx-lm into it.
      - name: Set up Python
        run: |
          uv venv venv
          source venv/bin/activate
          uv pip install --upgrade mlx-lm

      # Builds the wheel from the checkout with the ROCm backend switched on
      # for gfx1151, then force-reinstalls it over whatever mlx is present in
      # the virtualenv.
      - name: Build and install MLX ROCm wheel
        run: |
          source venv/bin/activate
          export CMAKE_ARGS="-DMLX_BUILD_ROCM=ON -DMLX_ROCM_ARCHITECTURES=gfx1151 -DBLA_VENDOR=OpenBLAS -DCMAKE_BUILD_TYPE=RelWithDebInfo"
          rm -rf wheelhouse
          mkdir -p wheelhouse
          uv build --wheel --out-dir wheelhouse .
          uv pip install --force-reinstall wheelhouse/mlx-*.whl

      # Smoke test: array creation, matmul and softmax evaluated with the GPU
      # set as the default device.
      - name: Basic MLX GPU test
        run: |
          source venv/bin/activate
          python3 -c "
          import mlx.core as mx
          print('MLX version:', mx.__version__)
          print('Default device:', mx.default_device())
          mx.set_default_device(mx.gpu)
          print('GPU device set')

          # Test basic operations
          a = mx.ones((10, 10))
          mx.eval(a)
          print('Basic array creation: OK')

          # Test matmul
          b = mx.random.normal((256, 256))
          c = mx.matmul(b, b)
          mx.eval(c)
          print('Matmul test: OK')

          # Test softmax
          d = mx.softmax(b, axis=-1)
          mx.eval(d)
          print('Softmax test: OK')

          print('All basic tests passed!')
          "

      # Runs mlx_lm.generate under lldb in batch mode. The `-k` commands are
      # lldb's on-crash hooks: they print a backtrace and make lldb exit with
      # status 1 if the process crashes. All output is tee'd into
      # rocm-stacktraces/<name>.log for the upload step further down.
      # PYTHONFAULTHANDLER=1 turns on Python's faulthandler; HIP_LAUNCH_BLOCKING=1
      # presumably serialises kernel launches so faults point at the offending
      # kernel -- confirm against the HIP documentation.
      - name: Run inference tests
        run: |
          source venv/bin/activate
          export HIP_LAUNCH_BLOCKING=1
          export PYTHONFAULTHANDLER=1
          mkdir -p "${GITHUB_WORKSPACE}/rocm-stacktraces"

          run_and_trace() {
            local name="$1"
            shift
            lldb -Q -b \
              -o "run" \
              -k "bt" \
              -k "quit 1" \
              -- python3 "$(which mlx_lm.generate)" "$@" \
              > >(tee "${GITHUB_WORKSPACE}/rocm-stacktraces/${name}.log") 2>&1
          }

          run_and_trace qwen3_bf16 --model mlx-community/Qwen3-0.6B-bf16 --prompt "Hi" --max-tokens 5
          run_and_trace qwen3_8bit --model mlx-community/Qwen3-0.6B-8bit --prompt "How tall is Mt Everest?" --max-tokens 128

      # Both upload steps run even when an earlier step failed, so the wheel
      # and the logs stay available for debugging.
      - name: Upload ROCm wheel artifacts
        if: always()
        uses: actions/upload-artifact@v6
        with:
          name: rocm-wheel-${{ github.run_attempt }}
          path: wheelhouse/mlx-*.whl
          if-no-files-found: warn
          retention-days: 14

      - name: Upload ROCm stacktrace artifacts
        if: always()
        uses: actions/upload-artifact@v6
        with:
          name: rocm-stacktraces-${{ github.run_attempt }}
          path: ${{ github.workspace }}/rocm-stacktraces/*
          if-no-files-found: warn
          retention-days: 14
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,10 @@ uv.lock
.cache/
# vim
*.swp

# keys
*.pem

build.sh
github-runner/
sync_fork.sh
44 changes: 42 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
# Backend selection: Metal and CPU are built by default; CUDA and ROCm are
# opt-in.
option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_BUILD_CPU "Build cpu backend" ON)
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
option(MLX_BUILD_ROCM "Build rocm backend" OFF)
# Other switches: Metal debugging aid, x64 macOS target, GGUF format support.
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
Expand Down Expand Up @@ -162,6 +163,43 @@ if(MLX_BUILD_CUDA)
endif()
endif()

if(MLX_BUILD_ROCM)
  # Select the GPU architectures the ROCm backend is compiled for. The value
  # is consumed by the ROCm backend CMakeLists.txt. Precedence:
  #
  # 1. CMAKE_HIP_ARCHITECTURES, when it is already defined (e.g. via -D)
  # 2. MLX_ROCM_ARCHITECTURES, the project-specific override
  # 3. the built-in default list below
  #
  # The default list covers:
  #
  # * CDNA:  gfx908, gfx90a, gfx942
  # * RDNA1: gfx1010, gfx1011, gfx1012
  # * RDNA2: gfx1030, gfx1031, gfx1032
  # * RDNA3: gfx1100, gfx1101, gfx1102
  #
  # Newer targets (e.g. gfx950, gfx1200, gfx1201) are not in the default list
  # and must be requested explicitly through one of the two variables above.
  #
  # The result is stored in the cache, so on a re-configure a changed
  # MLX_ROCM_ARCHITECTURES only takes effect after CMAKE_HIP_ARCHITECTURES
  # has been removed from the cache.
  if(NOT DEFINED CMAKE_HIP_ARCHITECTURES)
    if(DEFINED MLX_ROCM_ARCHITECTURES)
      # Quoted so the list is forwarded as a single value regardless of its
      # contents.
      set(CMAKE_HIP_ARCHITECTURES
          "${MLX_ROCM_ARCHITECTURES}"
          CACHE STRING "HIP architectures")
    else()
      set(CMAKE_HIP_ARCHITECTURES
          "gfx908;gfx90a;gfx942;gfx1010;gfx1011;gfx1012;gfx1030;gfx1031;gfx1032;gfx1100;gfx1101;gfx1102"
          CACHE STRING "HIP architectures")
    endif()
  endif()
  message(
    STATUS "Setting CMAKE_HIP_ARCHITECTURES to: ${CMAKE_HIP_ARCHITECTURES}")

  # Note: We don't enable_language(HIP) here because it causes CMake to add
  # -x hip to all CXX files in targets that link to HIP libraries. Instead, we
  # compile HIP files using custom commands in the ROCm backend
  # CMakeLists.txt.
  #
  # Locate the HIP compiler. The ROCM_PATH / HIP_PATH environment variables
  # are consulted first (HINTS are searched before the system PATH), so a
  # ROCm installation outside /opt/rocm is preferred over an unrelated
  # clang++ that merely happens to be on PATH. The hard-coded /opt/rocm
  # locations remain as a last-resort fallback. PATH_SUFFIXES lets a hint
  # that points at the ROCm root resolve to <root>/bin (hipcc) or
  # <root>/llvm/bin (clang++).
  find_program(
    CMAKE_HIP_COMPILER
    NAMES hipcc clang++
    HINTS ENV ROCM_PATH ENV HIP_PATH
    PATHS /opt/rocm/bin /opt/rocm-6.0.0/bin /opt/rocm/llvm/bin
    PATH_SUFFIXES bin llvm/bin
    DOC "HIP compiler")
  if(NOT CMAKE_HIP_COMPILER)
    message(FATAL_ERROR "Could not find HIP compiler (hipcc or clang++)")
  endif()
  message(STATUS "Found HIP compiler: ${CMAKE_HIP_COMPILER}")
endif()

if(MLX_BUILD_METAL)
find_library(METAL_LIB Metal)
find_library(FOUNDATION_LIB Foundation)
Expand Down Expand Up @@ -290,10 +328,12 @@ if(MLX_BUILD_CPU)
message(FATAL_ERROR "Must have LAPACK installed")
endif()
find_path(LAPACK_INCLUDE_DIRS lapacke.h /usr/include /usr/local/include
/usr/local/opt/openblas/include)
/usr/local/opt/openblas/include /usr/include/openblas)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
if(LAPACK_INCLUDE_DIRS)
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
endif()
target_link_libraries(mlx PRIVATE ${LAPACK_LIBRARIES})
# List blas after lapack otherwise we may accidentally include an old
# version of lapack.h from the include dirs of blas.
Expand Down
Loading