
Fix/janus pro flash attention #152

Open
vadseshu wants to merge 2 commits into ROCm:develop from vadseshu:fix/janus_pro_flash_attention

Conversation

@vadseshu
Contributor

Motivation

Fixed the FlashAttention (FA) compilation steps, following the ROCm support instructions from https://github.com/dao-ailab/flash-attention, and fixed the error: AttributeError: 'MultiModalityCausalLM' object has no attribute 'all_tied_weights_keys'. Did you mean: '_tied_weights_keys'?

Technical Details

Janus + Transformers 5.x — PretrainedConfig / dataclasses

Change: after cloning, in both janus/models/modeling_vlm.py and janus/janusflow/models/modeling_vlm.py:
  • add from dataclasses import field after the PretrainedConfig import;
  • replace params: AttrDict = {} with params: AttrDict = field(default_factory=dict).

Why: Transformers 5.x turns these configs into dataclasses, and a plain {} is an illegal mutable default, which caused:
ValueError: mutable default … for field params … use default_factory.
A minimal sketch of the failure and the fix is shown below.
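The failure reproduces outside Janus with any dataclass. A minimal Python sketch (VisionConfig and AttrDict here are illustrative stand-ins, not the real Janus classes):

from dataclasses import dataclass, field

class AttrDict(dict):
    """Illustrative stand-in for Janus's attribute-access dict."""
    __getattr__ = dict.__getitem__

@dataclass
class VisionConfig:
    # params: AttrDict = {}  # rejected at class-creation time:
    # ValueError: mutable default <class 'dict'> for field params is not allowed: use default_factory
    params: AttrDict = field(default_factory=dict)  # each instance gets a fresh dict

a, b = VisionConfig(), VisionConfig()
a.params["select_layer"] = -1
assert b.params == {}  # no mutable default shared between instances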
Janus + Transformers 5.x — from_pretrained / weight loading

Change: append self.post_init() at the end of MultiModalityCausalLM.__init__ in:
  • the main model, right after self.language_model = LlamaForCausalLM(language_config);
  • the Janus-Flow model, right after self.vision_gen_dec_aligner = nn.Linear(2048, 768, bias=True).

Why: in Transformers 5.x, all_tied_weights_keys (and related setup) is populated in PreTrainedModel.post_init(). Janus never called it, which caused:
AttributeError: 'MultiModalityCausalLM' object has no attribute 'all_tied_weights_keys'.
A simplified sketch of the patched __init__ follows.
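For orientation, a sketch of the patched main-model __init__ (submodule setup abbreviated; the config plumbing is assumed rather than copied verbatim from janus/models/modeling_vlm.py):

from transformers import LlamaForCausalLM, PreTrainedModel

class MultiModalityCausalLM(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # ... vision tower, aligners, and generation heads elided ...
        self.language_model = LlamaForCausalLM(config.language_config)
        # New: Transformers 5.x builds all_tied_weights_keys (and related
        # weight-loading state) in PreTrainedModel.post_init(); without this
        # call, from_pretrained() hits the AttributeError above.
        self.post_init()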

Test Plan

Test Result

[screenshot of test result attached]

Submission Checklist


Copilot AI left a comment


Pull request overview

Updates AMD Docker images to (a) adjust FlashAttention build settings/targets and (b) patch Janus-Pro at build time for compatibility with Hugging Face Transformers 5.x weight loading.

Changes:

  • Add gfx950 to the default PYTORCH_ROCM_ARCH list used when building flash-attention in the Mochi inference image.
  • Replace the Janus-Pro image’s flash-attention build logic with a new clone-and-install flow and enable FLASH_ATTENTION_TRITON_AMD_ENABLE.
  • Apply Janus source patches during the Docker build to address Transformers 5.x dataclass defaults and post_init()/tied-weights setup.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

File | Description
docker/pyt_mochi_inference.ubuntu.amd.Dockerfile | Extends the default ROCm arch list for the flash-attention wheel build.
docker/pyt_janus_pro_inference.ubuntu.amd.Dockerfile | Changes the flash-attention installation approach and adds Docker-time Janus patches for Transformers 5.x.


Comment on lines +64 to +67

RUN git clone https://github.com/ROCm/flash-attention.git &&\
cd flash-attention &&\
python setup.py install

Copilot AI Apr 22, 2026


The flash-attention install is now unpinned (no branch/SHA) and skips git submodule update --init, and python setup.py install builds from the live HEAD. This makes builds non-reproducible and is likely to break (other Dockerfiles in this repo pin a ref and init submodules before building a wheel). Consider restoring pinning (FA_BRANCH/FA_SHA), initializing submodules, and installing via a built wheel (or pip install .) with explicit GPU_ARCHS=${HIP_ARCHITECTURES} to avoid long/incorrect builds.

Suggested change
RUN git clone https://github.com/ROCm/flash-attention.git &&\
cd flash-attention &&\
python setup.py install
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
ARG FA_BRANCH="v3.0.0.r1-cktile"
ARG FA_SHA=""
RUN cd ${APP_DIR} \
&& rm -rf flash-attention \
&& git clone --branch ${FA_BRANCH} ${FA_REPO} flash-attention \
&& cd flash-attention \
&& if [ -n "${FA_SHA}" ]; then git checkout ${FA_SHA}; fi \
&& git submodule update --init --recursive \
&& GPU_ARCHS=${HIP_ARCHITECTURES} python3 -m pip wheel . --no-build-isolation --wheel-dir=dist \
&& pip install dist/*.whl \
&& python -c "import flash_attn; print(f'Flash Attention version == {flash_attn.__version__}')"

Comment on lines 45 to +60
# Install flash attention
ARG BUILD_FA="1"
ARG FA_BRANCH="v3.0.0.r1-cktile"
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
RUN if [ "$BUILD_FA" = "1" ]; then \
cd ${APP_DIR} \
&& pip uninstall -y flash-attention \
&& rm -rf flash-attention \
&& git clone ${FA_REPO} \
&& cd flash-attention \
&& git checkout ${FA_BRANCH} \
&& git submodule update --init \
&& GPU_ARCHS=${HIP_ARCHITECTURES} python3 setup.py bdist_wheel --dist-dir=dist \
&& pip install dist/*.whl \
&& python -c "import flash_attn; print(f'Flash Attention version == {flash_attn.__version__}')"; \
fi
#ARG BUILD_FA="1"
#ARG FA_BRANCH="v3.0.0.r1-cktile"
#ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
#RUN if [ "$BUILD_FA" = "1" ]; then \
# cd ${APP_DIR} \
# && pip uninstall -y flash-attention \
# && rm -rf flash-attention \
# && git clone ${FA_REPO} \
# && cd flash-attention \
# && git checkout ${FA_BRANCH} \
# && git submodule update --init \
# && GPU_ARCHS=${HIP_ARCHITECTURES} python3 setup.py bdist_wheel --dist-dir=dist \
# && pip install dist/*.whl \
# && python -c "import flash_attn; print(f'Flash Attention version == {flash_attn.__version__}')"; \
# fi

Copilot AI Apr 22, 2026


The large commented-out flash-attention build block (BUILD_FA/FA_BRANCH/FA_REPO) looks like dead code after the new install path was added. To keep the Dockerfile maintainable, either remove the commented block or re-enable it behind a build ARG so there is a single authoritative installation path.

Comment on lines +78 to +81
&& sed -i 's/params: AttrDict = {}/params: AttrDict = field(default_factory=dict)/g' "$f"; \
done \
&& sed -i '/self\.language_model = LlamaForCausalLM(language_config)/a\ self.post_init()' janus/models/modeling_vlm.py \
&& sed -i '/self\.vision_gen_dec_aligner = nn.Linear(2048, 768, bias=True)/a\ self.post_init()' janus/janusflow/models/modeling_vlm.py \

Copilot AI Apr 22, 2026


The Janus/transformers 5.x compatibility changes are applied via sed against specific source lines. sed will still exit successfully even if the patterns stop matching upstream, which can silently produce an unpatched image and reintroduce the runtime errors this PR is addressing. Consider pinning Janus to a known commit and/or adding post-sed assertions (e.g., grep for default_factory=dict and self.post_init()) so the build fails if the patch didn’t apply.

Suggested change
&& sed -i 's/params: AttrDict = {}/params: AttrDict = field(default_factory=dict)/g' "$f"; \
done \
&& sed -i '/self\.language_model = LlamaForCausalLM(language_config)/a\ self.post_init()' janus/models/modeling_vlm.py \
&& sed -i '/self\.vision_gen_dec_aligner = nn.Linear(2048, 768, bias=True)/a\ self.post_init()' janus/janusflow/models/modeling_vlm.py \
&& sed -i 's/params: AttrDict = {}/params: AttrDict = field(default_factory=dict)/g' "$f" \
&& grep -q 'from dataclasses import field' "$f" \
&& grep -q 'params: AttrDict = field(default_factory=dict)' "$f"; \
done \
&& sed -i '/self\.language_model = LlamaForCausalLM(language_config)/a\ self.post_init()' janus/models/modeling_vlm.py \
&& grep -q 'self.post_init()' janus/models/modeling_vlm.py \
&& sed -i '/self\.vision_gen_dec_aligner = nn.Linear(2048, 768, bias=True)/a\ self.post_init()' janus/janusflow/models/modeling_vlm.py \
&& grep -q 'self.post_init()' janus/janusflow/models/modeling_vlm.py \
