diff --git a/docker/pyt_janus_pro_inference.ubuntu.amd.Dockerfile b/docker/pyt_janus_pro_inference.ubuntu.amd.Dockerfile
index 0a9b6c8..c958741 100644
--- a/docker/pyt_janus_pro_inference.ubuntu.amd.Dockerfile
+++ b/docker/pyt_janus_pro_inference.ubuntu.amd.Dockerfile
@@ -43,27 +43,42 @@ ENV HIP_ARCHITECTURES=${MAD_SYSTEM_GPU_ARCHITECTURE}
 RUN echo HIP_ARCHITECTURES = ${HIP_ARCHITECTURES}
 
 # Install flash attention
-ARG BUILD_FA="1"
-ARG FA_BRANCH="v3.0.0.r1-cktile"
-ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
-RUN if [ "$BUILD_FA" = "1" ]; then \
-    cd ${APP_DIR} \
-    && pip uninstall -y flash-attention \
-    && rm -rf flash-attention \
-    && git clone ${FA_REPO} \
-    && cd flash-attention \
-    && git checkout ${FA_BRANCH} \
-    && git submodule update --init \
-    && GPU_ARCHS=${HIP_ARCHITECTURES} python3 setup.py bdist_wheel --dist-dir=dist \
-    && pip install dist/*.whl \
-    && python -c "import flash_attn; print(f'Flash Attention version == {flash_attn.__version__}')"; \
-    fi
+#ARG BUILD_FA="1"
+#ARG FA_BRANCH="v3.0.0.r1-cktile"
+#ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
+#RUN if [ "$BUILD_FA" = "1" ]; then \
+#    cd ${APP_DIR} \
+#    && pip uninstall -y flash-attention \
+#    && rm -rf flash-attention \
+#    && git clone ${FA_REPO} \
+#    && cd flash-attention \
+#    && git checkout ${FA_BRANCH} \
+#    && git submodule update --init \
+#    && GPU_ARCHS=${HIP_ARCHITECTURES} python3 setup.py bdist_wheel --dist-dir=dist \
+#    && pip install dist/*.whl \
+#    && python -c "import flash_attn; print(f'Flash Attention version == {flash_attn.__version__}')"; \
+#    fi
 
-# Install Janus-pro
+# Install flash attention (Triton backend)
+ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
+
+RUN git clone https://github.com/ROCm/flash-attention.git && \
+    cd flash-attention && \
+    python setup.py install
+
+# Install Janus-pro (patches for Hugging Face transformers 5.x):
+# - @dataclass rejects a mutable default dict on PretrainedConfig subclasses; use default_factory
+# - PreTrainedModel.post_init() sets all_tied_weights_keys (required by from_pretrained / loading)
 RUN cd ${APP_DIR} \
     && git clone https://github.com/deepseek-ai/Janus.git \
     && cd Janus \
     && sed -i 's/^torch==/# torch==/' requirements.txt \
+    && for f in janus/models/modeling_vlm.py janus/janusflow/models/modeling_vlm.py; do \
+        sed -i '/^from transformers.configuration_utils import PretrainedConfig/a from dataclasses import field' "$f" \
+        && sed -i 's/params: AttrDict = {}/params: AttrDict = field(default_factory=dict)/g' "$f"; \
+    done \
+    && sed -i '/self\.language_model = LlamaForCausalLM(language_config)/a\        self.post_init()' janus/models/modeling_vlm.py \
+    && sed -i '/self\.vision_gen_dec_aligner = nn.Linear(2048, 768, bias=True)/a\        self.post_init()' janus/janusflow/models/modeling_vlm.py \
     && pip install -e . \
     && pip install datasets \
     && cd ..
diff --git a/docker/pyt_mochi_inference.ubuntu.amd.Dockerfile b/docker/pyt_mochi_inference.ubuntu.amd.Dockerfile
index 4cbbd1e..fb9426f 100644
--- a/docker/pyt_mochi_inference.ubuntu.amd.Dockerfile
+++ b/docker/pyt_mochi_inference.ubuntu.amd.Dockerfile
@@ -33,7 +33,7 @@ RUN mkdir -p $WORKSPACE_DIR
 WORKDIR $WORKSPACE_DIR
 
 ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
-ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942;gfx1100;gfx1101;gfx1200;gfx1201
+ARG PYTORCH_ROCM_ARCH=gfx950;gfx90a;gfx942;gfx1100;gfx1101;gfx1200;gfx1201
 RUN git clone ${FA_REPO}
 RUN cd flash-attention \
     && git submodule update --init \
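
Note on the replacement flash-attention path in the first hunk: the Triton backend is
selected through the FLASH_ATTENTION_TRITON_AMD_ENABLE variable baked in via ENV. A quick
post-build smoke test, mirroring the version print from the commented-out composable-kernel
path (a sketch for running inside the built image; the setdefault is a no-op there, since
ENV already sets the variable, and matters only when checking elsewhere):

    import os

    # Must be set before the import, assuming the backend reads it at import time,
    # as the ENV line in the Dockerfile suggests. No-op inside the image.
    os.environ.setdefault("FLASH_ATTENTION_TRITON_AMD_ENABLE", "TRUE")

    import flash_attn
    print(f"Flash Attention version == {flash_attn.__version__}")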
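
Why the default_factory patch to the Janus configs is needed: Python's dataclass machinery
rejects a mutable default such as {} at class-definition time, which is exactly what the
first pair of sed commands works around. A minimal, self-contained reproduction of the rule
(BadConfig and GoodConfig are stand-ins, not the actual Janus AttrDict/PretrainedConfig types):

    from dataclasses import dataclass, field

    try:
        @dataclass
        class BadConfig:
            params: dict = {}  # rejected when the decorator processes the class
    except ValueError as err:
        print(err)  # mutable default <class 'dict'> for field params is not allowed ...

    @dataclass
    class GoodConfig:
        params: dict = field(default_factory=dict)  # the patched shape

    a, b = GoodConfig(), GoodConfig()
    print(a.params is b.params)  # False: each instance gets its own dict

default_factory defers construction to instance-creation time, so instances no longer share
a single dict; that shared-state hazard is why dataclasses forbids the literal {} default in
the first place.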
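
The last two sed commands append self.post_init() immediately after the final submodule
assignment in each model's __init__, so the transformers loading hooks run once the module
tree is complete. A hedged sketch of the resulting placement, with stub classes standing in
for the real janus/transformers types (only the call site mirrors the patched file):

    class MultiModalityPreTrainedModel:  # stub for the transformers-derived base
        def post_init(self):
            print("post_init: transformers wires weight init / tied-weight metadata here")

    class LlamaForCausalLM:  # stub for the Hugging Face model
        def __init__(self, config): ...

    class MultiModalityCausalLM(MultiModalityPreTrainedModel):
        def __init__(self, language_config):
            # ... vision tower and aligner setup elided ...
            self.language_model = LlamaForCausalLM(language_config)
            self.post_init()  # the line the sed patch appends

    MultiModalityCausalLM(language_config=None)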