2 changes: 1 addition & 1 deletion docker/pyt_mochi_inference.ubuntu.amd.Dockerfile
@@ -33,7 +33,7 @@ RUN mkdir -p $WORKSPACE_DIR
WORKDIR $WORKSPACE_DIR

ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
-ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942;gfx1100;gfx1101;gfx1200;gfx1201
+ARG PYTORCH_ROCM_ARCH=gfx950;gfx90a;gfx942;gfx1100;gfx1101;gfx1200;gfx1201
RUN git clone ${FA_REPO}
RUN cd flash-attention \
&& git submodule update --init \
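For context on the hunk above: PYTORCH_ROCM_ARCH is a semicolon-separated list of ROCm GPU targets, and the change prepends gfx950 to it. A small shell sketch (illustrative only, not part of the Dockerfile) showing how such a list breaks down:

```shell
# The value set by the updated Dockerfile: a semicolon-separated list of
# ROCm GPU architectures that flash-attention kernels get compiled for.
PYTORCH_ROCM_ARCH="gfx950;gfx90a;gfx942;gfx1100;gfx1101;gfx1200;gfx1201"

# Split on ';' to list one target per line, then count them.
echo "$PYTORCH_ROCM_ARCH" | tr ';' '\n'
COUNT=$(echo "$PYTORCH_ROCM_ARCH" | tr ';' '\n' | wc -l)
echo "targets: $COUNT"
```

Each extra entry lengthens the flash-attention build, so trimming this list to the actual deployment hardware is a common local override via `--build-arg`.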
38 changes: 23 additions & 15 deletions docker/pyt_wan2.1_inference.ubuntu.amd.Dockerfile
@@ -64,21 +64,29 @@ RUN apt-get install -y locales
RUN locale-gen en_US.UTF-8

# Install flash attention
-ARG BUILD_FA="1"
-ARG FA_BRANCH="v3.0.0.r1-cktile"
-ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
-RUN if [ "$BUILD_FA" = "1" ]; then \
-    cd ${WORKSPACE_DIR} \
-    && pip uninstall -y flash-attention \
-    && rm -rf flash-attention \
-    && git clone ${FA_REPO} \
-    && cd flash-attention \
-    && git checkout ${FA_BRANCH} \
-    && git submodule update --init \
-    && GPU_ARCHS=${HIP_ARCHITECTURES} python3 setup.py bdist_wheel --dist-dir=dist \
-    && pip install dist/*.whl \
-    && python -c "import flash_attn; print(f'Flash Attention version == {flash_attn.__version__}')"; \
-    fi
+#ARG BUILD_FA="1"
+#ARG FA_BRANCH="v3.0.0.r1-cktile"
+#ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
+#RUN if [ "$BUILD_FA" = "1" ]; then \
+#    cd ${WORKSPACE_DIR} \
+#    && pip uninstall -y flash-attention \
+#    && rm -rf flash-attention \
+#    && git clone ${FA_REPO} \
+#    && cd flash-attention \
+#    && git checkout ${FA_BRANCH} \
+#    && git submodule update --init \
+#    && GPU_ARCHS=${HIP_ARCHITECTURES} python3 setup.py bdist_wheel --dist-dir=dist \
+#    && pip install dist/*.whl \
+#    && python -c "import flash_attn; print(f'Flash Attention version == {flash_attn.__version__}')"; \
+#    fi
+# install flash attention
+ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
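The ENV line above switches flash-attention from its CK build to its Triton backend on AMD GPUs; flash-attention reads the flag at import time. A hedged runtime sanity check (the echo logic is illustrative, not part of the image):

```shell
# The Dockerfile bakes this flag into the image so flash-attention selects
# its Triton AMD backend instead of the composable-kernel (CK) build.
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"

# Illustrative check inside a running container
if [ "$FLASH_ATTENTION_TRITON_AMD_ENABLE" = "TRUE" ]; then
  echo "triton amd backend enabled"
else
  echo "triton amd backend disabled"
fi
```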

+RUN git clone https://github.com/ROCm/flash-attention.git &&\
+    cd flash-attention &&\
+    python setup.py install

Contributor comment on lines +67 to +87:
Please use the FA_BRANCH and FA_REPO arguments. They exist so the image can be built from whatever branch is needed via build arguments. The branch already there is the latest tag from flash-attention; is there a specific reason to remove it?



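A parameterized variant along the lines the reviewer asks for could look like the sketch below. This is hypothetical: the defaults are copied from the removed block, and whether `python setup.py install` or the wheel-based build is kept is an open question in the review.

```dockerfile
# Sketch only: keep repo and branch overridable via build arguments,
# as the reviewer requests. Defaults mirror the removed block.
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
ARG FA_BRANCH="v3.0.0.r1-cktile"
ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
RUN git clone ${FA_REPO} \
    && cd flash-attention \
    && git checkout ${FA_BRANCH} \
    && python setup.py install
```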
#Download wan2.1 source code
RUN cd $WORKSPACE_DIR \