Add AMD ROCm support: AITER attention backend + robust imports + docs #3

Open

ZJLi2013 wants to merge 1 commit into H-EmbodVis:main from ZJLi2013:feat/rocm-flash-attention-support


ZJLi2013 commented Apr 1, 2026

Summary

Enable HyDRA to run efficiently on AMD GPUs (ROCm) with optimized attention backends.

Problem

The current attention import guards use ModuleNotFoundError, which misses partial import failures (e.g. AITER's eager top-level imports). When flash-attn is not installed, the code silently falls back to PyTorch SDPA with no diagnostic logging, making it hard to tell which backend is active.

Changes

diffsynth/models/wan_video_dit.py

  • Widen exception handling from ModuleNotFoundError to (ImportError, ModuleNotFoundError) for the flash_attn, flash_attn_interface, and sageattention imports
  • Add AMD AITER as an attention backend (via importlib to avoid eager import side-effects), slotted between FA3 and FA2 in the dispatch chain
  • Log selected attention backend at import time for easier debugging
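The guard widening and importlib probe described above can be sketched roughly as follows. This is a minimal illustration, not the actual code in wan_video_dit.py: the real module uses its own flag names and applies the same pattern to flash_attn_interface and sageattention as well.

```python
import importlib.util
import logging

logger = logging.getLogger(__name__)

# Catch ImportError, not just ModuleNotFoundError: a package that is
# installed but fails partway through its own top-level imports (as AITER
# can on non-ROCm systems) raises a plain ImportError, which a
# ModuleNotFoundError-only guard would let propagate.
try:
    from flash_attn import flash_attn_func  # noqa: F401  (FA2)
    FLASH_ATTN_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
    FLASH_ATTN_AVAILABLE = False

# Probe for AITER with importlib.util.find_spec so availability can be
# checked without triggering its eager top-level imports.
AITER_AVAILABLE = importlib.util.find_spec("aiter") is not None

# Log the detected backends at import time for easier debugging.
logger.info(
    "Attention backends: flash_attn=%s aiter=%s",
    FLASH_ATTN_AVAILABLE, AITER_AVAILABLE,
)
```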

README.md

  • Add collapsible AMD ROCm installation guide (PyTorch ROCm, FlashAttention Triton build, AITER CK backend)

Dispatch priority

FA3 → AITER → FA2 → SageAttention → PyTorch SDPA (fallback)

No behavior change for NVIDIA users — AITER is only available on ROCm and gracefully skipped otherwise.
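The dispatch priority amounts to a first-available walk down a fixed chain, with SDPA as the guaranteed terminal fallback. A minimal sketch (the backend keys and flag dict are illustrative, not the PR's actual identifiers):

```python
# Priority chain from the PR: FA3 -> AITER -> FA2 -> SageAttention -> SDPA.
PRIORITY = ("fa3", "aiter", "fa2", "sage", "sdpa")

def pick_backend(available: dict) -> str:
    """Return the highest-priority backend whose availability flag is set.

    The PyTorch SDPA fallback ("sdpa") always exists, so the walk cannot
    fail even when no optional backend is installed.
    """
    for name in PRIORITY:
        if name == "sdpa" or available.get(name, False):
            return name
    return "sdpa"  # unreachable; kept for clarity

# On ROCm with AITER and FA2 installed, AITER wins:
# pick_backend({"aiter": True, "fa2": True}) -> "aiter"
# On NVIDIA without AITER, behavior is unchanged:
# pick_backend({"fa2": True}) -> "fa2"
```

Because AITER sits between FA3 and FA2, a ROCm machine with both AITER and FA2 prefers AITER, while an NVIDIA machine (where the AITER flag is never set) falls through exactly as before.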

Benchmarks (AMD MI300X, ROCm 6.4, FA2 Triton backend)

| Metric | Before (SDPA) | After (FA2 Triton) | Delta |
| --- | --- | --- | --- |
| Steady-state step time | ~12.4 s/step | ~10.0 s/step | -19% |
| Total inference (4 samples) | ~50 min | ~43 min | -13% |
AITER with CK backend (ROCm 7.x) is expected to provide an additional ~25% speedup over Triton, based on benchmarks from other Wan2.1-based pipelines.

Test plan

  • Verified on AMD MI300X (gfx942) + ROCm 6.4 + PyTorch 2.9.1
  • FA2 Triton backend correctly detected and used (confirmed via log output)
  • Generated videos are visually consistent with SDPA baseline
  • NVIDIA GPU regression test (no AITER installed → should fall through to existing backends unchanged)
3_concat.mp4
