Add AMD ROCm/HIP support (2-line fix)#17

Closed
ZJLi2013 wants to merge 4 commits into JeffreyXiang:main from PhysicalAI-AIM:rocm

Conversation

Contributor

@ZJLi2013 ZJLi2013 commented Apr 8, 2026

Summary

Enable FlexGEMM to compile and run on AMD GPUs (ROCm/HIP) with only 2 lines of core code changed.

Changes

| File | Change | Why |
| --- | --- | --- |
| `flex_gemm/kernels/cuda/spconv/migemm_neighmap_pp.cu` | `#define __syncwarp(...) __builtin_amdgcn_wave_barrier()` (guarded by `__HIP_PLATFORM_AMD__`) | `__syncwarp()` (no-arg form) is not in HIP builtins. AMD wavefronts execute in SIMD lockstep, so `wave_barrier` is semantically equivalent. |
| `flex_gemm/kernels/triton/spconv/config.py` | `allow_tf32 = not torch.version.hip` | TF32 (TensorFloat-32) is NVIDIA Tensor Core specific. AMD CDNA matrix cores use IEEE precision (fp32/fp16). The Triton kernels already have `input_precision='tf32' if allow_tf32 else 'ieee'` branching — this change just sets the flag correctly at config level. |

Scope

Covered (tested and passing):

  • All 4 Triton-based algorithms: IMPLICIT_GEMM, IMPLICIT_GEMM_SPLITK, MASKED_IMPLICIT_GEMM, MASKED_IMPLICIT_GEMM_SPLITK
  • CUDA C++ utility modules (hashmap, serialize, neighbor map) — auto-hipified by setup.py

Not tested:

  • EXPLICIT_GEMM (pure torch.mm / im2col path — expected to work via hipBLAS, but not explicitly verified)
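The tested/untested split above can be captured as a small lookup, e.g. for gating CI on ROCm runners. This is a hypothetical helper, not part of the flex_gemm API:

```python
# Hypothetical helper (not in flex_gemm): which algorithms this PR
# verified on ROCm, per the Scope section above.
ROCM_TESTED = {
    "IMPLICIT_GEMM",
    "IMPLICIT_GEMM_SPLITK",
    "MASKED_IMPLICIT_GEMM",
    "MASKED_IMPLICIT_GEMM_SPLITK",
}
ROCM_UNTESTED = {"EXPLICIT_GEMM"}  # expected to work via hipBLAS, unverified

def rocm_status(algorithm_name):
    """Return 'tested', 'untested', or 'unknown' for an algorithm name."""
    if algorithm_name in ROCM_TESTED:
        return "tested"
    if algorithm_name in ROCM_UNTESTED:
        return "untested"
    return "unknown"
```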

Usage on AMD GPUs

```bash
# Build (setup.py auto-detects HIP and runs hipify)
pip install . --no-build-isolation
```

```python
# Use — no code changes needed, default algorithm works out of the box
import flex_gemm
# Default: MASKED_IMPLICIT_GEMM_SPLITK (Triton, ROCm compatible)

# Explicitly set algorithm if needed
from flex_gemm.ops import spconv
spconv.set_algorithm(spconv.Algorithm.IMPLICIT_GEMM)           # OK
spconv.set_algorithm(spconv.Algorithm.MASKED_IMPLICIT_GEMM)    # OK
```

Testing

Hardware: AMD Instinct MI300X
Software: ROCm 6.4.3, PyTorch 2.6.0, Triton 3.2.0
Docker: rocm/pytorch:rocm6.4.3_ubuntu24.04_py3.12_pytorch_release_2.6.0

| Test | Result |
| --- | --- |
| Hashmap insert/lookup (3D, idx_as_val) | PASS |
| Serialize z-order encode/decode | PASS |
| Serialize hilbert encode/decode | PASS |
| Sparse SubM Conv3d Forward (IMPLICIT_GEMM) | PASS |
| Sparse SubM Conv3d Backward (IMPLICIT_GEMM) | PASS |
| Masked Implicit GEMM Forward + Backward | PASS |
| Masked Implicit GEMM SplitK | PASS |

Notes

  • Complementary to PR #15 (which added a missing __syncwarp in neighbor_map.cu for NVIDIA correctness). This PR fixes a different file (migemm_neighmap_pp.cu) for AMD/HIP portability.
  • setup.py already has robust HIP support (IS_HIP_EXTENSION, --offload-arch), and hipify auto-converted 100% of CUDA calls (0 unsupported). No build system changes were needed.
  • The warpSize difference (32 on NVIDIA vs 64 on AMD) caused zero failures — the existing adaptive logic in reduce_code_kernel handles 64-wide wavefronts correctly.
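The width-adaptive reduction pattern referred to above can be modeled in plain Python. This is a sketch of the general warp/wavefront tree-reduction technique, not the actual reduce_code_kernel source:

```python
def tree_reduce(values, width):
    """Model a tree reduction over one warp/wavefront.

    Works for any power-of-two width: 32 lanes on an NVIDIA warp,
    64 lanes on an AMD wavefront. Each step halves the stride,
    mirroring a __shfl_down-style lane exchange.
    """
    assert len(values) == width and width & (width - 1) == 0
    lanes = list(values)
    offset = width // 2
    while offset:
        # Each active lane adds the value held `offset` lanes away.
        for lane in range(offset):
            lanes[lane] += lanes[lane + offset]
        offset //= 2
    return lanes[0]  # lane 0 holds the full sum
```

A kernel written this way needs no source change when warpSize goes from 32 to 64, which matches the PR's observation that the wavefront-width difference caused zero failures.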

@ZJLi2013 ZJLi2013 closed this Apr 8, 2026
@ZJLi2013 ZJLi2013 deleted the rocm branch April 8, 2026 02:57
