
Fix/wan2.1 flash attention #153

Open

vadseshu wants to merge 2 commits into ROCm:develop from vadseshu:fix/wan2.1_flash_attention

Conversation

@vadseshu
Contributor

Motivation

Updated the Wan2.1 Dockerfile with Flash Attention installation steps taken from the ROCm flash-attention repo.

Technical Details

Test Plan

Test Result

[screenshot of test result]

Submission Checklist


Copilot AI left a comment


Pull request overview

Updates AMD inference Dockerfiles to adjust FlashAttention build/install behavior (notably for Wan2.1) and expands the supported ROCm arch list for Mochi.

Changes:

  • Replaces the pinned/parameterized FlashAttention wheel build in the Wan2.1 Dockerfile with a direct setup.py install from an unpinned ROCm/flash-attention clone.
  • Adds gfx950 to the PYTORCH_ROCM_ARCH list in the Mochi inference Dockerfile (a build-arg override sketch follows below).
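
Since PYTORCH_ROCM_ARCH is declared with ARG, the architecture list can also be overridden at build time without editing the Dockerfile. A minimal sketch (the image tag is illustrative; the file path and arch names come from this PR, and the semicolon-separated list must be quoted for the shell):

docker build \
  -f docker/pyt_mochi_inference.ubuntu.amd.Dockerfile \
  --build-arg PYTORCH_ROCM_ARCH="gfx950;gfx90a;gfx942" \
  -t mochi-inference:test .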

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

File / Description:
  • docker/pyt_wan2.1_inference.ubuntu.amd.Dockerfile: Changes FlashAttention installation steps for Wan2.1 image builds.
  • docker/pyt_mochi_inference.ubuntu.amd.Dockerfile: Updates the ROCm architecture list used when building FlashAttention.


Comment on lines +67 to +87
#ARG BUILD_FA="1"
#ARG FA_BRANCH="v3.0.0.r1-cktile"
#ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
#RUN if [ "$BUILD_FA" = "1" ]; then \
# cd ${WORKSPACE_DIR} \
# && pip uninstall -y flash-attention \
# && rm -rf flash-attention \
# && git clone ${FA_REPO} \
# && cd flash-attention \
# && git checkout ${FA_BRANCH} \
# && git submodule update --init \
# && GPU_ARCHS=${HIP_ARCHITECTURES} python3 setup.py bdist_wheel --dist-dir=dist \
# && pip install dist/*.whl \
# && python -c "import flash_attn; print(f'Flash Attention version == {flash_attn.__version__}')"; \
# fi
# install flash attention
ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"

RUN git clone https://github.com/ROCm/flash-attention.git &&\
cd flash-attention &&\
python setup.py install
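
Note that the new unpinned install drops the import check the commented-out block performed; the same one-line verification could be re-added after the install step (a sketch, reusing the check from the old block):

RUN python -c "import flash_attn; print(f'Flash Attention version == {flash_attn.__version__}')"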
Comment on lines 35 to +36
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
-ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942;gfx1100;gfx1101;gfx1200;gfx1201
+ARG PYTORCH_ROCM_ARCH=gfx950;gfx90a;gfx942;gfx1100;gfx1101;gfx1200;gfx1201
# install flash attention
ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"

RUN git clone https://github.com/ROCm/flash-attention.git &&\
Contributor


Please use the FA_BRANCH and FA_REPO arguments; they exist so the image can be built from whatever branch is needed via build arguments. The existing branch is the latest tag from flash-attention, so is there any specific reason to remove it?
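
For illustration, a sketch of how the new Triton-backend install could still honor those build arguments (defaults copied from the commented-out block; this is a suggestion, not the PR's current change):

# Parameterized clone/checkout, keeping the Triton AMD backend enabled:
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
ARG FA_BRANCH="v3.0.0.r1-cktile"
ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
RUN git clone ${FA_REPO} && \
    cd flash-attention && \
    git checkout ${FA_BRANCH} && \
    python setup.py install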


4 participants