Conversation
Pull request overview
Updates AMD Docker images to (a) adjust FlashAttention build settings/targets and (b) patch Janus-Pro at build time for compatibility with Hugging Face Transformers 5.x weight loading.
Changes:
- Add `gfx950` to the default `PYTORCH_ROCM_ARCH` list used when building flash-attention in the Mochi inference image.
- Replace the Janus-Pro image's flash-attention build logic with a new clone-and-install flow and enable `FLASH_ATTENTION_TRITON_AMD_ENABLE`.
- Apply Janus source patches during the Docker build to address Transformers 5.x dataclass defaults and `post_init()`/tied-weights setup.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| docker/pyt_mochi_inference.ubuntu.amd.Dockerfile | Extends the default ROCm arch list for flash-attention wheel build. |
| docker/pyt_janus_pro_inference.ubuntu.amd.Dockerfile | Changes flash-attention installation approach and adds Docker-time Janus patches for Transformers 5.x. |
```dockerfile
RUN git clone https://github.com/ROCm/flash-attention.git &&\
    cd flash-attention &&\
    python setup.py install
```
The flash-attention install is now unpinned (no branch/SHA), skips `git submodule update --init`, and `python setup.py install` builds from the live HEAD. This makes builds non-reproducible and is likely to break (other Dockerfiles in this repo pin a ref and init submodules before building a wheel). Consider restoring pinning (`FA_BRANCH`/`FA_SHA`), initializing submodules, and installing via a built wheel (or `pip install .`) with an explicit `GPU_ARCHS=${HIP_ARCHITECTURES}` to avoid long or incorrect builds.
Suggested change:

```diff
-RUN git clone https://github.com/ROCm/flash-attention.git &&\
-    cd flash-attention &&\
-    python setup.py install
+ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
+ARG FA_BRANCH="v3.0.0.r1-cktile"
+ARG FA_SHA=""
+RUN cd ${APP_DIR} \
+    && rm -rf flash-attention \
+    && git clone --branch ${FA_BRANCH} ${FA_REPO} flash-attention \
+    && cd flash-attention \
+    && if [ -n "${FA_SHA}" ]; then git checkout ${FA_SHA}; fi \
+    && git submodule update --init --recursive \
+    && GPU_ARCHS=${HIP_ARCHITECTURES} python3 -m pip wheel . --no-build-isolation --wheel-dir=dist \
+    && pip install dist/*.whl \
+    && python -c "import flash_attn; print(f'Flash Attention version == {flash_attn.__version__}')"
```
```diff
-# Install flash attention
-ARG BUILD_FA="1"
-ARG FA_BRANCH="v3.0.0.r1-cktile"
-ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
-RUN if [ "$BUILD_FA" = "1" ]; then \
-    cd ${APP_DIR} \
-    && pip uninstall -y flash-attention \
-    && rm -rf flash-attention \
-    && git clone ${FA_REPO} \
-    && cd flash-attention \
-    && git checkout ${FA_BRANCH} \
-    && git submodule update --init \
-    && GPU_ARCHS=${HIP_ARCHITECTURES} python3 setup.py bdist_wheel --dist-dir=dist \
-    && pip install dist/*.whl \
-    && python -c "import flash_attn; print(f'Flash Attention version == {flash_attn.__version__}')"; \
-    fi
+#ARG BUILD_FA="1"
+#ARG FA_BRANCH="v3.0.0.r1-cktile"
+#ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
+#RUN if [ "$BUILD_FA" = "1" ]; then \
+#    cd ${APP_DIR} \
+#    && pip uninstall -y flash-attention \
+#    && rm -rf flash-attention \
+#    && git clone ${FA_REPO} \
+#    && cd flash-attention \
+#    && git checkout ${FA_BRANCH} \
+#    && git submodule update --init \
+#    && GPU_ARCHS=${HIP_ARCHITECTURES} python3 setup.py bdist_wheel --dist-dir=dist \
+#    && pip install dist/*.whl \
+#    && python -c "import flash_attn; print(f'Flash Attention version == {flash_attn.__version__}')"; \
+#    fi
```
The large commented-out flash-attention build block (`BUILD_FA`/`FA_BRANCH`/`FA_REPO`) looks like dead code now that the new install path has been added. To keep the Dockerfile maintainable, either remove the commented block or re-enable it behind a build ARG so there is a single authoritative installation path.
```dockerfile
    && sed -i 's/params: AttrDict = {}/params: AttrDict = field(default_factory=dict)/g' "$f"; \
    done \
    && sed -i '/self\.language_model = LlamaForCausalLM(language_config)/a\ self.post_init()' janus/models/modeling_vlm.py \
    && sed -i '/self\.vision_gen_dec_aligner = nn.Linear(2048, 768, bias=True)/a\ self.post_init()' janus/janusflow/models/modeling_vlm.py \
```
The Janus/Transformers 5.x compatibility changes are applied via `sed` against specific source lines. `sed` will still exit successfully even if the patterns stop matching upstream, which can silently produce an unpatched image and reintroduce the runtime errors this PR is addressing. Consider pinning Janus to a known commit and/or adding post-sed assertions (e.g., `grep` for `default_factory=dict` and `self.post_init()`) so the build fails if the patch didn't apply.
Suggested change:

```diff
-    && sed -i 's/params: AttrDict = {}/params: AttrDict = field(default_factory=dict)/g' "$f"; \
-    done \
-    && sed -i '/self\.language_model = LlamaForCausalLM(language_config)/a\ self.post_init()' janus/models/modeling_vlm.py \
-    && sed -i '/self\.vision_gen_dec_aligner = nn.Linear(2048, 768, bias=True)/a\ self.post_init()' janus/janusflow/models/modeling_vlm.py \
+    && sed -i 's/params: AttrDict = {}/params: AttrDict = field(default_factory=dict)/g' "$f" \
+    && grep -q 'from dataclasses import field' "$f" \
+    && grep -q 'params: AttrDict = field(default_factory=dict)' "$f"; \
+    done \
+    && sed -i '/self\.language_model = LlamaForCausalLM(language_config)/a\ self.post_init()' janus/models/modeling_vlm.py \
+    && grep -q 'self.post_init()' janus/models/modeling_vlm.py \
+    && sed -i '/self\.vision_gen_dec_aligner = nn.Linear(2048, 768, bias=True)/a\ self.post_init()' janus/janusflow/models/modeling_vlm.py \
+    && grep -q 'self.post_init()' janus/janusflow/models/modeling_vlm.py \
```
Motivation
Fixed the FA compilation steps, which are taken from the ROCm support instructions in https://github.com/dao-ailab/flash-attention, and fixed `AttributeError: 'MultiModalityCausalLM' object has no attribute 'all_tied_weights_keys'. Did you mean: '_tied_weights_keys'?`
Technical Details
2. Janus + Transformers 5.x — PretrainedConfig / dataclasses
Change: after cloning, for both `janus/models/modeling_vlm.py` and `janus/janusflow/models/modeling_vlm.py`:
- add `from dataclasses import field` after the `PretrainedConfig` import;
- replace `params: AttrDict = {}` with `params: AttrDict = field(default_factory=dict)`.
Why: Transformers 5.x turns those configs into dataclasses; a plain `{}` is an illegal mutable default and caused:
`ValueError: mutable default … for field params … use default_factory.`
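For illustration, here is a minimal standalone Python sketch (the config class and field names are hypothetical, not the actual Janus code) of why a plain mutable default breaks once the config becomes a dataclass, and how `field(default_factory=dict)` avoids it:

```python
from dataclasses import dataclass, field

# With dataclass-based configs, a mutable default is rejected at
# class-definition time:
#
#   @dataclass
#   class BadConfig:
#       params: dict = {}   # ValueError: mutable default <class 'dict'> for
#                           # field params is not allowed: use default_factory

@dataclass
class GoodConfig:
    # default_factory creates a fresh dict per instance, so instances never
    # share (and silently mutate) a single default object.
    params: dict = field(default_factory=dict)

a, b = GoodConfig(), GoodConfig()
a.params["x"] = 1
print(b.params)  # {} -> b is unaffected; each instance owns its dict
```

This mirrors the substitution that the Docker-time `sed` patch performs on the Janus config classes.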
3. Janus + Transformers 5.x — from_pretrained / weight loading
Change: append `self.post_init()` at the end of `MultiModalityCausalLM.__init__` in:
- Main model: right after `self.language_model = LlamaForCausalLM(language_config)`.
- Janus-Flow model: right after `self.vision_gen_dec_aligner = nn.Linear(2048, 768, bias=True)`.
Why: in Transformers 5.x, `all_tied_weights_keys` (and related setup) is applied in `PreTrainedModel.post_init()`. Janus never called it, which caused:
`AttributeError: 'MultiModalityCausalLM' object has no attribute 'all_tied_weights_keys'.`
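As a rough sketch of the pattern (the model and config names below are hypothetical, not the real Janus classes), the fix amounts to ending `__init__` of a `PreTrainedModel` subclass with `self.post_init()`:

```python
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig


class MyMultiModalConfig(PretrainedConfig):
    model_type = "my_multimodal"  # hypothetical config


class MyMultiModalModel(PreTrainedModel):
    config_class = MyMultiModalConfig

    def __init__(self, config):
        super().__init__(config)
        # Hypothetical submodules standing in for the Janus language/vision parts.
        self.vision_proj = nn.Linear(1024, 2048)
        self.aligner = nn.Linear(2048, 768, bias=True)
        # post_init() runs weight initialization and the tied-weights
        # bookkeeping that from_pretrained() relies on; skipping it is what
        # triggers the all_tied_weights_keys AttributeError quoted above.
        self.post_init()
```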
Test Plan
Test Result
Submission Checklist