diff --git a/docker/pyt_mochi_inference.ubuntu.amd.Dockerfile b/docker/pyt_mochi_inference.ubuntu.amd.Dockerfile
index 4cbbd1e..fb9426f 100644
--- a/docker/pyt_mochi_inference.ubuntu.amd.Dockerfile
+++ b/docker/pyt_mochi_inference.ubuntu.amd.Dockerfile
@@ -33,7 +33,7 @@ RUN mkdir -p $WORKSPACE_DIR
 WORKDIR $WORKSPACE_DIR
 
 ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
-ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942;gfx1100;gfx1101;gfx1200;gfx1201
+ARG PYTORCH_ROCM_ARCH=gfx950;gfx90a;gfx942;gfx1100;gfx1101;gfx1200;gfx1201
 RUN git clone ${FA_REPO}
 RUN cd flash-attention \
     && git submodule update --init \
diff --git a/docker/pyt_wan2.1_inference.ubuntu.amd.Dockerfile b/docker/pyt_wan2.1_inference.ubuntu.amd.Dockerfile
index 55e0504..6a2dbb6 100644
--- a/docker/pyt_wan2.1_inference.ubuntu.amd.Dockerfile
+++ b/docker/pyt_wan2.1_inference.ubuntu.amd.Dockerfile
@@ -64,21 +64,29 @@ RUN apt-get install -y locales
 RUN locale-gen en_US.UTF-8
 
 # Install flash attention
-ARG BUILD_FA="1"
-ARG FA_BRANCH="v3.0.0.r1-cktile"
-ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
-RUN if [ "$BUILD_FA" = "1" ]; then \
-    cd ${WORKSPACE_DIR} \
-    && pip uninstall -y flash-attention \
-    && rm -rf flash-attention \
-    && git clone ${FA_REPO} \
-    && cd flash-attention \
-    && git checkout ${FA_BRANCH} \
-    && git submodule update --init \
-    && GPU_ARCHS=${HIP_ARCHITECTURES} python3 setup.py bdist_wheel --dist-dir=dist \
-    && pip install dist/*.whl \
-    && python -c "import flash_attn; print(f'Flash Attention version == {flash_attn.__version__}')"; \
-    fi
+#ARG BUILD_FA="1"
+#ARG FA_BRANCH="v3.0.0.r1-cktile"
+#ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
+#RUN if [ "$BUILD_FA" = "1" ]; then \
+#    cd ${WORKSPACE_DIR} \
+#    && pip uninstall -y flash-attention \
+#    && rm -rf flash-attention \
+#    && git clone ${FA_REPO} \
+#    && cd flash-attention \
+#    && git checkout ${FA_BRANCH} \
+#    && git submodule update --init \
+#    && GPU_ARCHS=${HIP_ARCHITECTURES} python3 setup.py bdist_wheel --dist-dir=dist \
+#    && pip install dist/*.whl \
+#    && python -c "import flash_attn; print(f'Flash Attention version == {flash_attn.__version__}')"; \
+#    fi
+# install flash attention
+ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
+
+RUN git clone https://github.com/ROCm/flash-attention.git &&\
+    cd flash-attention &&\
+    python setup.py install
+
+
 #Download wan2.1 source code
 RUN cd $WORKSPACE_DIR \
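Note on the Wan2.1 change: the CK-tile flash-attention build is left commented out and the ROCm fork is installed with its Triton backend instead, enabled through the FLASH_ATTENTION_TRITON_AMD_ENABLE variable set by the ENV instruction in the patch. A minimal sketch of how the rebuilt image could be checked (the image tag wan21-rocm is an illustrative placeholder, not part of this patch):

    # build the Wan2.1 inference image from the updated Dockerfile
    docker build -f docker/pyt_wan2.1_inference.ubuntu.amd.Dockerfile -t wan21-rocm .

    # confirm flash_attn imports inside the container and report its version
    docker run --rm wan21-rocm \
        python -c "import flash_attn; print(f'Flash Attention version == {flash_attn.__version__}')"

If the Triton-backed install succeeded, the import prints the flash_attn version; no extra environment setup is needed at run time because FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" is baked into the image.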