Fix/janus pro flash attention #152
base: develop
```diff
@@ -43,27 +43,42 @@ ENV HIP_ARCHITECTURES=${MAD_SYSTEM_GPU_ARCHITECTURE}
 RUN echo HIP_ARCHITECTURES = ${HIP_ARCHITECTURES}
 
-# Install flash attention
-ARG BUILD_FA="1"
-ARG FA_BRANCH="v3.0.0.r1-cktile"
-ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
-RUN if [ "$BUILD_FA" = "1" ]; then \
-    cd ${APP_DIR} \
-    && pip uninstall -y flash-attention \
-    && rm -rf flash-attention \
-    && git clone ${FA_REPO} \
-    && cd flash-attention \
-    && git checkout ${FA_BRANCH} \
-    && git submodule update --init \
-    && GPU_ARCHS=${HIP_ARCHITECTURES} python3 setup.py bdist_wheel --dist-dir=dist \
-    && pip install dist/*.whl \
-    && python -c "import flash_attn; print(f'Flash Attention version == {flash_attn.__version__}')"; \
-    fi
+#ARG BUILD_FA="1"
+#ARG FA_BRANCH="v3.0.0.r1-cktile"
+#ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
+#RUN if [ "$BUILD_FA" = "1" ]; then \
+#    cd ${APP_DIR} \
+#    && pip uninstall -y flash-attention \
+#    && rm -rf flash-attention \
+#    && git clone ${FA_REPO} \
+#    && cd flash-attention \
+#    && git checkout ${FA_BRANCH} \
+#    && git submodule update --init \
+#    && GPU_ARCHS=${HIP_ARCHITECTURES} python3 setup.py bdist_wheel --dist-dir=dist \
+#    && pip install dist/*.whl \
+#    && python -c "import flash_attn; print(f'Flash Attention version == {flash_attn.__version__}')"; \
+#    fi
 
 # Install Janus-pro
+# install flash attention
+ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
+
+RUN git clone https://github.com/ROCm/flash-attention.git &&\
+    cd flash-attention &&\
+    python setup.py install
```
Comment on lines +64 to +67
Suggested change:

```diff
-RUN git clone https://github.com/ROCm/flash-attention.git &&\
-    cd flash-attention &&\
-    python setup.py install
+ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
+ARG FA_BRANCH="v3.0.0.r1-cktile"
+ARG FA_SHA=""
+RUN cd ${APP_DIR} \
+    && rm -rf flash-attention \
+    && git clone --branch ${FA_BRANCH} ${FA_REPO} flash-attention \
+    && cd flash-attention \
+    && if [ -n "${FA_SHA}" ]; then git checkout ${FA_SHA}; fi \
+    && git submodule update --init --recursive \
+    && GPU_ARCHS=${HIP_ARCHITECTURES} python3 -m pip wheel . --no-build-isolation --wheel-dir=dist \
+    && pip install dist/*.whl \
+    && python -c "import flash_attn; print(f'Flash Attention version == {flash_attn.__version__}')"
```
Copilot AI · Apr 22, 2026
The Janus/transformers 5.x compatibility changes are applied via sed against specific source lines. sed will still exit successfully even if the patterns stop matching upstream, which can silently produce an unpatched image and reintroduce the runtime errors this PR is addressing. Consider pinning Janus to a known commit and/or adding post-sed assertions (e.g., grep for default_factory=dict and self.post_init()) so the build fails if the patch didn’t apply.
Suggested change:

```diff
-    && sed -i 's/params: AttrDict = {}/params: AttrDict = field(default_factory=dict)/g' "$f"; \
-    done \
-    && sed -i '/self\.language_model = LlamaForCausalLM(language_config)/a\        self.post_init()' janus/models/modeling_vlm.py \
-    && sed -i '/self\.vision_gen_dec_aligner = nn.Linear(2048, 768, bias=True)/a\        self.post_init()' janus/janusflow/models/modeling_vlm.py \
+    && sed -i 's/params: AttrDict = {}/params: AttrDict = field(default_factory=dict)/g' "$f" \
+    && grep -q 'from dataclasses import field' "$f" \
+    && grep -q 'params: AttrDict = field(default_factory=dict)' "$f"; \
+    done \
+    && sed -i '/self\.language_model = LlamaForCausalLM(language_config)/a\        self.post_init()' janus/models/modeling_vlm.py \
+    && grep -q 'self.post_init()' janus/models/modeling_vlm.py \
+    && sed -i '/self\.vision_gen_dec_aligner = nn.Linear(2048, 768, bias=True)/a\        self.post_init()' janus/janusflow/models/modeling_vlm.py \
+    && grep -q 'self.post_init()' janus/janusflow/models/modeling_vlm.py \
```
The large commented-out flash-attention build block (BUILD_FA/FA_BRANCH/FA_REPO) looks like dead code after the new install path was added. To keep the Dockerfile maintainable, either remove the commented block or re-enable it behind a build ARG so there is a single authoritative installation path.
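One way to act on this, sketched as a hypothetical consolidation: reuse the ARG names from the commented-out block and gate a single install path behind `BUILD_FA`, so there is one authoritative recipe instead of live code plus a dead copy. The defaults below are illustrative, not confirmed for this repo.

```dockerfile
# Hypothetical single install path, toggled by a build ARG.
# ARG names mirror the commented-out block; values are assumptions.
ARG BUILD_FA="1"
ARG FA_BRANCH="v3.0.0.r1-cktile"
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
RUN if [ "$BUILD_FA" = "1" ]; then \
        cd ${APP_DIR} \
        && rm -rf flash-attention \
        && git clone --branch ${FA_BRANCH} ${FA_REPO} flash-attention \
        && cd flash-attention \
        && git submodule update --init --recursive \
        && GPU_ARCHS=${HIP_ARCHITECTURES} pip install . --no-build-isolation \
        && python -c "import flash_attn; print(flash_attn.__version__)"; \
    fi
```

Builds that need the Triton AMD path instead can pass `--build-arg BUILD_FA=0` and rely on the `FLASH_ATTENTION_TRITON_AMD_ENABLE` route, keeping both options visible without duplicated blocks.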