diff --git a/docs/source/Installation.md b/docs/source/Installation.md index 70fa95efa..ca5ee0d45 100644 --- a/docs/source/Installation.md +++ b/docs/source/Installation.md @@ -30,6 +30,19 @@ to install GPU accelerated {doc}`cuEquivariance attention kernels `, us pip install openfold3[cuequivariance] ``` +To use AMD ROCm-compatible Triton kernels, first install the ROCm PyTorch wheel (which bundles ROCm Triton), then install openfold3: + +```bash +pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.2 +pip install openfold3 +``` + +After installation, verify your ROCm environment is correctly configured: + +```bash +validate-openfold3-rocm +``` + (installation-environment-variables)= ### Environment variables diff --git a/docs/source/inference.md b/docs/source/inference.md index ca35dadd3..e8daf122a 100644 --- a/docs/source/inference.md +++ b/docs/source/inference.md @@ -197,6 +197,7 @@ We provide several example runner files in our [examples directory](https://gith - Using low memory settings - Customizing output formats - Enabling cuEquivariance kernels +- Enabling AMD ROCm Triton kernels - Saving MSA and Template processing outputs - And more @@ -297,6 +298,36 @@ model_update: --- +#### 🔴 AMD ROCm Inference with Triton Kernels + +On AMD GPUs, OpenFold3 can use native Triton kernels for the Evoformer attention and TriangleMultiplicativeUpdate layers instead of the default CUDA-specific kernels. + +First, install PyTorch for ROCm and openfold3 (see [Installation](https://github.com/aqlaboratory/openfold-3/blob/main/docs/source/Installation.md)). 
+Then enable the Triton kernels in your `runner.yml` using the provided [`triton.yml`](https://github.com/aqlaboratory/openfold-3/blob/main/examples/example_runner_yamls/triton.yml) example: + +```yaml +model_update: + presets: + - predict + custom: + settings: + memory: + eval: + use_triton_triangle_kernels: true + use_deepspeed_evo_attention: false + use_cueq_triangle_kernels: false +``` + +```bash +run_openfold predict \ + --query-json /path/to/query.json \ + --output-dir /path/to/output/ \ + --runner-yaml examples/example_runner_yamls/triton.yml +``` + +> **Note on first-run compilation**: Triton JIT-compiles kernels on first use and caches them to `~/.triton/cache`. The compilation is a one-time cost per unique sequence length per machine; subsequent runs at the same length incur no overhead. + +--- ### 3.4 Customized ColabFold MSA Server Settings Using `runner.yml` diff --git a/environments/production-amd-linux-64.yml b/environments/production-amd-linux-64.yml new file mode 100644 index 000000000..f85a14c3e --- /dev/null +++ b/environments/production-amd-linux-64.yml @@ -0,0 +1,39 @@ +name: openfold3-env +channels: + - conda-forge + - bioconda + - pytorch +dependencies: + - python + - awscli + - setuptools + - pip + - conda-forge::uv + - pytorch-lightning + - biopython + - numpy + - pandas + - PyYAML + - requests + - scipy + - tqdm + - typing-extensions + - wandb + - modelcif + - ml-collections + - mkl + - rdkit=2025.09.3 + - biotite==1.2.0 + - bioconda::hmmer + - bioconda::hhsuite + - bioconda::kalign2 + - memory_profiler + - func_timeout + - boto3 + - conda-forge::python-lmdb=1.6 + - conda-forge::ijson + - pip: + - pdbeccdutils + - --extra-index-url https://download.pytorch.org/whl/rocm7.2 + - torch + - torchvision diff --git a/examples/example_runner_yamls/triton.yml b/examples/example_runner_yamls/triton.yml new file mode 100644 index 000000000..b614acff7 --- /dev/null +++ b/examples/example_runner_yamls/triton.yml @@ -0,0 +1,11 @@ +model_update: + 
presets: + - predict + - low_mem # to use low memory settings + custom: + settings: + memory: + eval: + use_triton_triangle_kernels: true + use_deepspeed_evo_attention: false + use_cueq_triangle_kernels: false diff --git a/openfold3/core/kernels/triton/evoformer.py b/openfold3/core/kernels/triton/evoformer.py new file mode 100644 index 000000000..03a40ff11 --- /dev/null +++ b/openfold3/core/kernels/triton/evoformer.py @@ -0,0 +1,1126 @@ +# Copyright 2026 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +try: + import triton + import triton.language as tl + + _TRITON_AVAILABLE = True +except ImportError: + _TRITON_AVAILABLE = False + +# Sentinel: replaced with EvoformerAttention.apply when Triton is available. 
+TritonEvoformer = None + +if _TRITON_AVAILABLE: + + def is_hip(): + """Check if the current backend is HIP.""" + return triton.runtime.driver.active.get_current_target().backend == "hip" + + @triton.jit + def _attn_fwd_inner( + O_block, + l_i, + m_i, + Q_block, + K_block_ptr, + V_block_ptr, + res_mask_block_ptr, + pair_bias_block_ptr, + block_index_q, + DIM, + stride_K_seq, + stride_V_seq, + stride_mask_seq, + stride_pair_bias_seq2, + softmax_scale, + EVEN_Q: tl.constexpr, + EVEN_KV: tl.constexpr, + EVEN_DIM: tl.constexpr, + HAS_PAIR_BIAS: tl.constexpr, + BLOCK_SIZE_Q: tl.constexpr, + BLOCK_SIZE_KV: tl.constexpr, + BLOCK_DIM: tl.constexpr, + offs_q: tl.constexpr, + offs_kv: tl.constexpr, + offs_d: tl.constexpr, + SEQ_LEN: tl.constexpr, + ): + """Run the inner loop of the forward pass of the attention mechanism.""" + lo, hi = 0, SEQ_LEN + Q_block = Q_block * tl.full((1,), softmax_scale, dtype=Q_block.dtype) + + for start_kv in range(lo, hi, BLOCK_SIZE_KV): + start_kv = tl.multiple_of(start_kv, BLOCK_SIZE_KV) + if EVEN_Q & EVEN_KV: + if HAS_PAIR_BIAS: + pair_bias_block = tl.load(pair_bias_block_ptr) + res_mask_block = tl.load(res_mask_block_ptr).broadcast_to( + (BLOCK_SIZE_Q, BLOCK_SIZE_KV) + ) + if EVEN_DIM: + K_block = tl.load(K_block_ptr) + V_block = tl.load(V_block_ptr) + else: + K_block = tl.load( + K_block_ptr, mask=offs_d[:, None] < DIM, other=0.0 + ) + V_block = tl.load( + V_block_ptr, mask=offs_d[None, :] < DIM, other=0.0 + ) + else: + if HAS_PAIR_BIAS: + pair_bias_block = tl.load( + pair_bias_block_ptr, + mask=(offs_q[:, None] < SEQ_LEN) + & ((start_kv + offs_kv)[None, :] < SEQ_LEN), + other=float("-inf"), + ) + res_mask_block = tl.load( + res_mask_block_ptr, + mask=(start_kv + offs_kv)[None, :] < SEQ_LEN, + other=float("-inf"), + ).broadcast_to((BLOCK_SIZE_Q, BLOCK_SIZE_KV)) + if EVEN_DIM: + K_block = tl.load( + K_block_ptr, + mask=(start_kv + offs_kv)[None, :] < SEQ_LEN, + other=0.0, + ) + V_block = tl.load( + V_block_ptr, + mask=(start_kv + offs_kv)[:, 
None] < SEQ_LEN, + other=0.0, + ) + else: + K_block = tl.load( + K_block_ptr, + mask=((start_kv + offs_kv)[None, :] < SEQ_LEN) + & (offs_d[:, None] < DIM), + other=0.0, + ) + V_block = tl.load( + V_block_ptr, + mask=((start_kv + offs_kv)[:, None] < SEQ_LEN) + & (offs_d[None, :] < DIM), + other=0.0, + ) + + QK_block = tl.dot(Q_block, K_block) + res_mask_block + if HAS_PAIR_BIAS: + QK_block += pair_bias_block + + if not EVEN_KV: + QK_block += tl.where( + (start_kv + offs_kv)[None, :] < SEQ_LEN, 0, float("-inf") + ) + + m_ij = tl.maximum(m_i, tl.max(QK_block, 1)) + QK_block = QK_block - m_ij[:, None] + + P_block = tl.math.exp(QK_block) + l_ij = tl.sum(P_block, 1) + + alpha = tl.math.exp(m_i - m_ij) + l_i = l_i * alpha + l_ij + + P_block = P_block.to(V_block.dtype) + O_block = O_block * alpha[:, None] + O_block = tl.dot(P_block, V_block, O_block) + + m_i = m_ij + + V_block_ptr += BLOCK_SIZE_KV * stride_V_seq + K_block_ptr += BLOCK_SIZE_KV * stride_K_seq + if HAS_PAIR_BIAS: + pair_bias_block_ptr += BLOCK_SIZE_KV * stride_pair_bias_seq2 + res_mask_block_ptr += BLOCK_SIZE_KV * stride_mask_seq + + return O_block, l_i, m_i + + @triton.heuristics( + { + "EVEN_Q": lambda args: args["SEQ_LEN"] % args["BLOCK_SIZE_Q"] == 0, + "EVEN_KV": lambda args: args["SEQ_LEN"] % args["BLOCK_SIZE_KV"] == 0, + "EVEN_DIM": lambda args: args["DIM"] == args["BLOCK_DIM"], + } + ) + @triton.jit + def _attn_fwd( + Q, # BATCH_SIZE, N_SEQ, HEAD, SEQ_LEN, DIM + K, # BATCH_SIZE, N_SEQ, HEAD, SEQ_LEN, DIM + V, # BATCH_SIZE, N_SEQ, HEAD, SEQ_LEN, DIM + res_mask, # BATCH_SIZE, N_SEQ, 1, SEQ_LEN, 1 + pair_bias, # BATCH_SIZE, 1, HEAD, SEQ_LEN, SEQ_LEN + softmax_scale, + M, # BATCH_SIZE, N_SEQ, HEAD, SEQ_LEN + O, # BATCH_SIZE, N_SEQ, HEAD, SEQ_LEN, DIM + stride_Q_batch, + stride_Q_msa, + stride_Q_head, + stride_Q_seq, + stride_Q_dim, + stride_K_batch, + stride_K_msa, + stride_K_head, + stride_K_seq, + stride_K_dim, + stride_V_batch, + stride_V_msa, + stride_V_head, + stride_V_seq, + stride_V_dim, + 
stride_O_batch, + stride_O_msa, + stride_O_head, + stride_O_seq, + stride_O_dim, + stride_pair_bias_batch, + stride_pair_bias_head, + stride_pair_bias_seq1, + stride_pair_bias_seq2, + stride_mask_batch, + stride_mask_msa, + stride_mask_seq, + BATCH_SIZE, + HEAD: tl.constexpr, + N_SEQ: tl.constexpr, + SEQ_LEN: tl.constexpr, + DIM: tl.constexpr, + EVEN_Q: tl.constexpr, + EVEN_KV: tl.constexpr, + EVEN_DIM: tl.constexpr, + HAS_PAIR_BIAS: tl.constexpr, + BLOCK_SIZE_Q: tl.constexpr, + BLOCK_SIZE_KV: tl.constexpr, + BLOCK_DIM: tl.constexpr, + ): + """Run the forward pass of the attention mechanism.""" + block_index_q = tl.program_id(0) + + index_batch_msa_head = tl.program_id(1) + index_batch_msa = index_batch_msa_head // HEAD + index_head = index_batch_msa_head % HEAD + index_batch = index_batch_msa // N_SEQ + index_msa = index_batch_msa % N_SEQ + + # Cast to int64 to avoid int32 overflow for large sequences + qvk_offset = ( + index_batch.to(tl.int64) * stride_Q_batch + + index_msa.to(tl.int64) * stride_Q_msa + + index_head * stride_Q_head + ) + offs_q = block_index_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q) + offs_kv = tl.arange(0, BLOCK_SIZE_KV) + offs_d = tl.arange(0, BLOCK_DIM) + + Q_block_ptr = ( + Q + qvk_offset + (offs_q[:, None] * stride_Q_seq + offs_d[None, :]) + ) + V_block_ptr = ( + V + qvk_offset + (offs_kv[:, None] * stride_V_seq + offs_d[None, :]) + ) + K_block_ptr = ( + K + qvk_offset + (offs_kv[None, :] * stride_K_seq + offs_d[:, None]) + ) + pair_bias_block_ptr = ( + pair_bias + + index_batch * stride_pair_bias_batch + + index_head * stride_pair_bias_head + + ( + offs_q[:, None] * stride_pair_bias_seq1 + + offs_kv[None, :] * stride_pair_bias_seq2 + ) + ) + O_block_ptr = ( + O + qvk_offset + (offs_q[:, None] * stride_O_seq + offs_d[None, :]) + ) + + res_mask_block_ptr = ( + res_mask + + index_batch * stride_mask_batch + + index_msa * stride_mask_msa + + (offs_kv[None, :] * stride_mask_seq) + ) + + m_i = tl.zeros([BLOCK_SIZE_Q], dtype=tl.float32) - 
float("inf") + l_i = tl.zeros([BLOCK_SIZE_Q], dtype=tl.float32) + 1.0 + O_block = tl.zeros([BLOCK_SIZE_Q, BLOCK_DIM], dtype=tl.float32) + + # Load Q block; it stays in SRAM for the duration of the inner loop + if EVEN_Q & EVEN_KV: + if EVEN_DIM: + Q_block = tl.load(Q_block_ptr) + else: + Q_block = tl.load(Q_block_ptr, mask=offs_d[None, :] < DIM, other=0.0) + else: + if EVEN_DIM: + Q_block = tl.load( + Q_block_ptr, mask=offs_q[:, None] < SEQ_LEN, other=0.0 + ) + else: + Q_block = tl.load( + Q_block_ptr, + mask=(offs_q[:, None] < SEQ_LEN) & (offs_d[None, :] < DIM), + other=0.0, + ) + + O_block, l_i, m_i = _attn_fwd_inner( + O_block, + l_i, + m_i, + Q_block, + K_block_ptr, + V_block_ptr, + res_mask_block_ptr, + pair_bias_block_ptr, + block_index_q, + DIM, + stride_K_seq, + stride_V_seq, + stride_mask_seq, + stride_pair_bias_seq2, + softmax_scale, + EVEN_Q, + EVEN_KV, + EVEN_DIM, + HAS_PAIR_BIAS, + BLOCK_SIZE_Q, + BLOCK_SIZE_KV, + BLOCK_DIM, + offs_q, + offs_kv, + offs_d, + SEQ_LEN, + ) + + m_i += tl.math.log(l_i) + O_block = O_block / l_i[:, None] + O_block = O_block.to(O.type.element_ty) + m_ptrs = M + index_batch_msa_head * SEQ_LEN + offs_q + + if EVEN_Q: + tl.store(m_ptrs, m_i) + if EVEN_DIM: + tl.store(O_block_ptr, O_block) + else: + tl.store(O_block_ptr, O_block, mask=offs_d[None, :] < DIM) + else: + tl.store(m_ptrs, m_i, mask=offs_q < SEQ_LEN) + if EVEN_DIM: + tl.store(O_block_ptr, O_block, mask=offs_q[:, None] < SEQ_LEN) + else: + tl.store( + O_block_ptr, + O_block, + mask=(offs_q[:, None] < SEQ_LEN) & (offs_d[None, :] < DIM), + ) + + @triton.jit + def _attn_bwd_preprocess( + O, + dO, + D, + SEQ_LEN, + BLOCK_SIZE_Q: tl.constexpr, + DIM: tl.constexpr, + BLOCK_DIM: tl.constexpr, + ): + """Run the preprocessing step of the backward pass of the attention + mechanism.""" + block_index_q = tl.program_id(0) + offs_q = block_index_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q) + index_batch_msa_head = tl.program_id(1) + offs_dim = tl.arange(0, BLOCK_DIM) + + # Cast to 
int64 to avoid int32 overflow for large sequences + bwd_offset = index_batch_msa_head.to(tl.int64) * SEQ_LEN * DIM + + # Load a single block of BLOCK_SIZE_Q rows of O + O_block = tl.load( + O + bwd_offset + offs_q[:, None] * DIM + offs_dim[None, :], + mask=(offs_q[:, None] < SEQ_LEN) & (offs_dim[None, :] < DIM), + other=0.0, + ) + # Load a single block of BLOCK_SIZE_Q rows of dO + dO_block = tl.load( + dO + bwd_offset + offs_q[:, None] * DIM + offs_dim[None, :], + mask=(offs_q[:, None] < SEQ_LEN) & (offs_dim[None, :] < DIM), + other=0.0, + ).to(tl.float32) + # Compute the D block + D_block = tl.sum(dO_block * O_block, axis=1) # Shape: (BLOCK_SIZE_Q,) + # Store the D block + D_block_ptrs = D + index_batch_msa_head.to(tl.int64) * SEQ_LEN + offs_q + tl.store(D_block_ptrs, D_block, mask=offs_q < SEQ_LEN) + + @triton.heuristics( + { + "EVEN_Q": lambda args: args["SEQ_LEN"] % args["BLOCK_SIZE_Q"] == 0, + "EVEN_KV": lambda args: args["SEQ_LEN"] % args["BLOCK_SIZE_KV"] == 0, + "EVEN_DIM": lambda args: args["DIM"] == args["BLOCK_DIM"], + } + ) + @triton.jit + def _attn_bwd_dq( + Q, + K, + V, + res_mask, + pair_bias, + softmax_scale, + dO, + dQ, + dK, + dV, + d_pair_bias, + M, + D, + stride_batch, + stride_head, + stride_msa, + stride_seq, + stride_pair_bias_batch, + stride_pair_bias_head, + stride_pair_bias_seq1, + stride_pair_bias_seq2, + stride_mask_batch, + stride_mask_msa, + stride_mask_seq, + stride_d_pair_bias_batch, + stride_d_pair_bias_head, + stride_d_pair_bias_seq1, + stride_d_pair_bias_seq2, + HEAD, + N_SEQ, + SEQ_LEN, + BLOCK_DIM: tl.constexpr, + DIM: tl.constexpr, + EVEN_Q: tl.constexpr, + EVEN_KV: tl.constexpr, + EVEN_DIM: tl.constexpr, + BLOCK_SIZE_Q: tl.constexpr, + BLOCK_SIZE_KV: tl.constexpr, + ): + """Run the backward pass of the attention mechanism.""" + index_batch_msa_head = tl.program_id(1) + index_batch_msa = index_batch_msa_head // HEAD + index_head = index_batch_msa_head % HEAD + index_batch = index_batch_msa // N_SEQ + index_msa = index_batch_msa 
% N_SEQ + + # Cast indices to int64 to avoid int32 overflow + offset_batch_head_msa = ( + index_batch.to(tl.int64) * stride_batch + + index_head.to(tl.int64) * stride_head + + index_msa.to(tl.int64) * stride_msa + ) + offset_batch_head_msa_seq = index_batch_msa_head.to(tl.int64) * SEQ_LEN + + Q += offset_batch_head_msa + K += offset_batch_head_msa + V += offset_batch_head_msa + dO += offset_batch_head_msa + dQ += offset_batch_head_msa + dK += offset_batch_head_msa + dV += offset_batch_head_msa + + M += offset_batch_head_msa_seq + D += offset_batch_head_msa_seq + + offs_dim = tl.arange(0, BLOCK_DIM) + + index_block_kv = tl.program_id(0) + offs_q = index_block_kv * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q) + + dQ_block = tl.zeros([BLOCK_SIZE_Q, BLOCK_DIM], dtype=tl.float32) + + if EVEN_Q & EVEN_KV: + M_block = tl.load(M + offs_q) + Di = tl.load(D + offs_q) + if EVEN_DIM: + Q_block = tl.load(Q + offs_q[:, None] * stride_seq + offs_dim[None, :]) + dO_block = tl.load( + dO + offs_q[:, None] * stride_seq + offs_dim[None, :] + ) + else: + Q_block = tl.load( + Q + offs_q[:, None] * stride_seq + offs_dim[None, :], + mask=offs_dim[None, :] < DIM, + other=0.0, + ) + dO_block = tl.load( + dO + offs_q[:, None] * stride_seq + offs_dim[None, :], + mask=offs_dim[None, :] < DIM, + other=0.0, + ) + else: + M_block = tl.load(M + offs_q, mask=offs_q < SEQ_LEN, other=0.0) + Di = tl.load(D + offs_q, mask=offs_q < SEQ_LEN, other=0.0) + if EVEN_DIM: + Q_block = tl.load( + Q + offs_q[:, None] * stride_seq + offs_dim[None, :], + mask=offs_q[:, None] < SEQ_LEN, + other=0.0, + ) + dO_block = tl.load( + dO + offs_q[:, None] * stride_seq + offs_dim[None, :], + mask=offs_q[:, None] < SEQ_LEN, + other=0.0, + ) + else: + Q_block = tl.load( + Q + offs_q[:, None] * stride_seq + offs_dim[None, :], + mask=(offs_q[:, None] < SEQ_LEN) & (offs_dim[None, :] < DIM), + other=0.0, + ) + dO_block = tl.load( + dO + offs_q[:, None] * stride_seq + offs_dim[None, :], + mask=(offs_q[:, None] < SEQ_LEN) & 
(offs_dim[None, :] < DIM), + other=0.0, + ) + + M_block = M_block[:, None] + + offs_kv = tl.arange(0, BLOCK_SIZE_KV) + pair_bias_block_ptr = ( + pair_bias + + index_batch.to(tl.int64) * stride_pair_bias_batch + + index_head.to(tl.int64) * stride_pair_bias_head + + offs_q[:, None] * stride_pair_bias_seq1 + + offs_kv[None, :] * stride_pair_bias_seq2 + ) + + d_pair_bias_block_ptr = ( + d_pair_bias + + index_batch.to(tl.int64) * stride_d_pair_bias_batch + + index_head.to(tl.int64) * stride_d_pair_bias_head + + (offs_q[:, None] * stride_d_pair_bias_seq1) + + (offs_kv[None, :] * stride_d_pair_bias_seq2) + ) + res_mask_block_ptr = ( + res_mask + + index_batch.to(tl.int64) * stride_mask_batch + + index_msa.to(tl.int64) * stride_mask_msa + + (offs_kv[None, :] * stride_mask_seq) + ) + + kT_ptrs = K + offs_kv[None, :] * stride_seq + offs_dim[:, None] + vT_ptrs = V + offs_kv[None, :] * stride_seq + offs_dim[:, None] + + Q_block = Q_block * tl.full((1,), softmax_scale, dtype=Q_block.dtype) + + curr_kv = 0 + num_steps = (SEQ_LEN + BLOCK_SIZE_KV - 1) // BLOCK_SIZE_KV + + for blk_idx in range(num_steps): + if EVEN_Q & EVEN_KV: + pair_bias_block = tl.load(pair_bias_block_ptr) + res_mask_block = tl.load(res_mask_block_ptr).broadcast_to( + (BLOCK_SIZE_Q, BLOCK_SIZE_KV) + ) + if EVEN_DIM: + K_T_block = tl.load(kT_ptrs) + V_T_block = tl.load(vT_ptrs) + else: + K_T_block = tl.load( + kT_ptrs, mask=offs_dim[:, None] < DIM, other=0.0 + ) + V_T_block = tl.load( + vT_ptrs, mask=offs_dim[:, None] < DIM, other=0.0 + ) + else: + pair_bias_block = tl.load( + pair_bias_block_ptr, + mask=(offs_q[:, None] < SEQ_LEN) + & ((blk_idx * BLOCK_SIZE_KV + offs_kv)[None, :] < SEQ_LEN), + other=float("-inf"), + ) + res_mask_block = tl.load( + res_mask_block_ptr, + mask=(blk_idx * BLOCK_SIZE_KV + offs_kv)[None, :] < SEQ_LEN, + other=float("-inf"), + ).broadcast_to((BLOCK_SIZE_Q, BLOCK_SIZE_KV)) + if EVEN_DIM: + K_T_block = tl.load( + kT_ptrs, + mask=(blk_idx * BLOCK_SIZE_KV + offs_kv)[None, :] < SEQ_LEN, + 
other=0.0, + ) + V_T_block = tl.load( + vT_ptrs, + mask=(blk_idx * BLOCK_SIZE_KV + offs_kv)[None, :] < SEQ_LEN, + other=0.0, + ) + else: + K_T_block = tl.load( + kT_ptrs, + mask=((blk_idx * BLOCK_SIZE_KV + offs_kv)[None, :] < SEQ_LEN) + & (offs_dim[:, None] < DIM), + other=0.0, + ) + V_T_block = tl.load( + vT_ptrs, + mask=((blk_idx * BLOCK_SIZE_KV + offs_kv)[None, :] < SEQ_LEN) + & (offs_dim[:, None] < DIM), + other=0.0, + ) + + QK_block = tl.dot(Q_block, K_T_block) + pair_bias_block + res_mask_block + + if not EVEN_KV: + QK_block += tl.where( + (blk_idx * BLOCK_SIZE_KV + offs_kv)[None, :] < SEQ_LEN, + 0, + float("-inf"), + ) + + P_block = tl.math.exp(QK_block - M_block) + + dP_block = tl.dot(dO_block, V_T_block).to(tl.float32) + dS_block = P_block * (dP_block - Di[:, None]) + + # Update d_pair_bias atomic add with float32 precision + tl.atomic_add( + d_pair_bias_block_ptr, + dS_block, + mask=(offs_q[:, None] < SEQ_LEN) + & ((blk_idx * BLOCK_SIZE_KV + offs_kv)[None, :] < SEQ_LEN), + ) + dS_block = dS_block.to(K_T_block.dtype) + + dQ_block += softmax_scale * tl.dot(dS_block, tl.trans(K_T_block)) + + curr_kv += BLOCK_SIZE_KV + kT_ptrs += BLOCK_SIZE_KV * stride_seq + vT_ptrs += BLOCK_SIZE_KV * stride_seq + pair_bias_block_ptr += BLOCK_SIZE_KV * stride_pair_bias_seq2 + d_pair_bias_block_ptr += BLOCK_SIZE_KV * stride_d_pair_bias_seq2 + res_mask_block_ptr += BLOCK_SIZE_KV * stride_mask_seq + + dQ_block_ptrs = dQ + offs_q[:, None] * stride_seq + offs_dim[None, :] + if EVEN_Q & EVEN_KV: + if EVEN_DIM: + tl.store(dQ_block_ptrs, dQ_block) + else: + tl.store(dQ_block_ptrs, dQ_block, mask=offs_dim[None, :] < DIM) + else: + if EVEN_DIM: + tl.store(dQ_block_ptrs, dQ_block, mask=offs_q[:, None] < SEQ_LEN) + else: + tl.store( + dQ_block_ptrs, + dQ_block, + mask=(offs_q[:, None] < SEQ_LEN) & (offs_dim[None, :] < DIM), + ) + + @triton.heuristics( + { + "EVEN_Q": lambda args: args["SEQ_LEN"] % args["BLOCK_SIZE_Q"] == 0, + "EVEN_KV": lambda args: args["SEQ_LEN"] % 
args["BLOCK_SIZE_KV"] == 0, + "EVEN_DIM": lambda args: args["DIM"] == args["BLOCK_DIM"], + } + ) + @triton.jit + def _attn_bwd_dk_dv( + Q, + K, + V, + res_mask, + pair_bias, + softmax_scale, + dO, + dQ, + dK, + dV, + M, + D, + stride_batch, + stride_head, + stride_msa, + stride_seq, + stride_pair_bias_batch, + stride_pair_bias_head, + stride_pair_bias_seq1, + stride_pair_bias_seq2, + stride_mask_batch, + stride_mask_msa, + stride_mask_seq, + HEAD, + N_SEQ, + SEQ_LEN, + BLOCK_DIM: tl.constexpr, + DIM: tl.constexpr, + EVEN_Q: tl.constexpr, + EVEN_KV: tl.constexpr, + EVEN_DIM: tl.constexpr, + BLOCK_SIZE_Q: tl.constexpr, + BLOCK_SIZE_KV: tl.constexpr, + ): + """Run the backward pass of the attention mechanism.""" + index_batch_msa_head = tl.program_id(1) + index_batch_msa = index_batch_msa_head // HEAD + index_head = index_batch_msa_head % HEAD + index_batch = index_batch_msa // N_SEQ + index_msa = index_batch_msa % N_SEQ + + # Cast indices to int64 to avoid int32 overflow + offset_batch_msa_head = ( + index_batch.to(tl.int64) * stride_batch + + index_msa.to(tl.int64) * stride_msa + + index_head.to(tl.int64) * stride_head + ) + offset_batch_msa_head_seq = index_batch_msa_head.to(tl.int64) * SEQ_LEN + + Q += offset_batch_msa_head + K += offset_batch_msa_head + V += offset_batch_msa_head + dO += offset_batch_msa_head + dQ += offset_batch_msa_head + dK += offset_batch_msa_head + dV += offset_batch_msa_head + + M += offset_batch_msa_head_seq + D += offset_batch_msa_head_seq + + offs_dim = tl.arange(0, BLOCK_DIM) + + index_block_kv = tl.program_id(0) + offs_kv = index_block_kv * BLOCK_SIZE_KV + tl.arange(0, BLOCK_SIZE_KV) + offs_q = tl.arange(0, BLOCK_SIZE_Q) + + dK_block = tl.zeros([BLOCK_SIZE_KV, BLOCK_DIM], dtype=tl.float32) + dV_block = tl.zeros([BLOCK_SIZE_KV, BLOCK_DIM], dtype=tl.float32) + + res_mask_block_ptr = ( + res_mask + + index_batch.to(tl.int64) * stride_mask_batch + + index_msa.to(tl.int64) * stride_mask_msa + + offs_kv[None, :] * stride_mask_seq + ) + + # K 
and V stay in SRAM throughout the inner loop + if EVEN_Q & EVEN_KV: + res_mask_T_block = tl.trans(tl.load(res_mask_block_ptr)).broadcast_to( + (BLOCK_SIZE_KV, BLOCK_SIZE_Q) + ) + if EVEN_DIM: + K_block = tl.load( + K + offs_kv[:, None] * stride_seq + offs_dim[None, :] + ) # Shape: (BLOCK_SIZE_KV, DIM) + V_block = tl.load( + V + offs_kv[:, None] * stride_seq + offs_dim[None, :] + ) # Shape: (BLOCK_SIZE_KV, DIM) + else: + K_block = tl.load( + K + offs_kv[:, None] * stride_seq + offs_dim[None, :], + mask=offs_dim[None, :] < DIM, + other=0.0, + ) # Shape: (BLOCK_SIZE_KV, DIM) + V_block = tl.load( + V + offs_kv[:, None] * stride_seq + offs_dim[None, :], + mask=offs_dim[None, :] < DIM, + other=0.0, + ) # Shape: (BLOCK_SIZE_KV, DIM) + else: + res_mask_T_block = tl.trans( + tl.load( + res_mask_block_ptr, + mask=offs_kv[None, :] < SEQ_LEN, + other=float("-inf"), + ) + ).broadcast_to((BLOCK_SIZE_KV, BLOCK_SIZE_Q)) + if EVEN_DIM: + K_block = tl.load( + K + offs_kv[:, None] * stride_seq + offs_dim[None, :], + mask=offs_kv[:, None] < SEQ_LEN, + other=0.0, + ) + V_block = tl.load( + V + offs_kv[:, None] * stride_seq + offs_dim[None, :], + mask=offs_kv[:, None] < SEQ_LEN, + other=0.0, + ) + else: + K_block = tl.load( + K + offs_kv[:, None] * stride_seq + offs_dim[None, :], + mask=(offs_kv[:, None] < SEQ_LEN) & (offs_dim[None, :] < DIM), + other=0.0, + ) + V_block = tl.load( + V + offs_kv[:, None] * stride_seq + offs_dim[None, :], + mask=(offs_kv[:, None] < SEQ_LEN) & (offs_dim[None, :] < DIM), + other=0.0, + ) + + pair_bias_T_block_ptr = ( + pair_bias + + ( + index_batch.to(tl.int64) * stride_pair_bias_batch + + index_head.to(tl.int64) * stride_pair_bias_head + ) + + offs_q[None, :] * stride_pair_bias_seq1 + + offs_kv[:, None] * stride_pair_bias_seq2 + ) + qT_ptrs = Q + offs_q[None, :] * stride_seq + offs_dim[:, None] + dO_ptrs = dO + offs_q[:, None] * stride_seq + offs_dim[None, :] + + K_block = K_block * tl.full((1,), softmax_scale, dtype=K_block.dtype) + + curr_q = 0 + 
num_steps = (SEQ_LEN + BLOCK_SIZE_Q - 1) // BLOCK_SIZE_Q + + for _blk_idx in range(num_steps): + offs_q = curr_q + tl.arange(0, BLOCK_SIZE_Q) + + if EVEN_Q & EVEN_KV: + m = tl.load(M + offs_q) + pair_bias_T_block = tl.load(pair_bias_T_block_ptr) + Di = tl.load(D + offs_q) # [(BLOCK_SIZE_Q, )] + if EVEN_DIM: + qT_block = tl.load(qT_ptrs) + dO_block = tl.load(dO_ptrs) + else: + qT_block = tl.load(qT_ptrs, mask=offs_dim[:, None] < DIM, other=0.0) + dO_block = tl.load(dO_ptrs, mask=offs_dim[None, :] < DIM, other=0.0) + else: + m = tl.load(M + offs_q, mask=offs_q < SEQ_LEN, other=0.0) + pair_bias_T_block = tl.load( + pair_bias_T_block_ptr, + mask=(offs_q[None, :] < SEQ_LEN) & (offs_kv[:, None] < SEQ_LEN), + other=float("-inf"), + ) + Di = tl.load(D + offs_q, mask=offs_q < SEQ_LEN, other=0.0) + if EVEN_DIM: + qT_block = tl.load( + qT_ptrs, mask=offs_q[None, :] < SEQ_LEN, other=0.0 + ) + dO_block = tl.load( + dO_ptrs, mask=offs_q[:, None] < SEQ_LEN, other=0.0 + ) + else: + qT_block = tl.load( + qT_ptrs, + mask=(offs_q[None, :] < SEQ_LEN) & (offs_dim[:, None] < DIM), + other=0.0, + ) + dO_block = tl.load( + dO_ptrs, + mask=(offs_q[:, None] < SEQ_LEN) & (offs_dim[None, :] < DIM), + other=0.0, + ) + + # Compute P^T = K Q^T (transposed attention scores) + QK_T_block = ( + tl.dot(K_block, qT_block) + pair_bias_T_block + res_mask_T_block + ) + + if not (EVEN_Q & EVEN_KV): + QK_T_block += tl.where( + (offs_kv[:, None] < SEQ_LEN) & (offs_q[None, :] < SEQ_LEN), + 0, + float("-inf"), + ) + + P_T_block = tl.math.exp(QK_T_block - m[None, :]) + + dV_block += tl.dot(P_T_block.to(K_block.dtype), dO_block) + + dpT_block = tl.dot(V_block, tl.trans(dO_block)).to(tl.float32) + dS_T_block = P_T_block * (dpT_block - Di[None, :]) + dS_T_block = dS_T_block.to(K_block.dtype) + + dK_block += softmax_scale * tl.dot(dS_T_block, tl.trans(qT_block)) + + # Increment pointers + curr_q += BLOCK_SIZE_Q + qT_ptrs += BLOCK_SIZE_Q * stride_seq + dO_ptrs += BLOCK_SIZE_Q * stride_seq + pair_bias_T_block_ptr 
+= BLOCK_SIZE_Q * stride_pair_bias_seq1 + + dV_block_ptrs = dV + offs_kv[:, None] * stride_seq + offs_dim[None, :] + dK_block_ptrs = dK + offs_kv[:, None] * stride_seq + offs_dim[None, :] + + if EVEN_Q & EVEN_KV: + if EVEN_DIM: + tl.store(dV_block_ptrs, dV_block) + tl.store(dK_block_ptrs, dK_block) + else: + tl.store(dV_block_ptrs, dV_block, mask=offs_dim[None, :] < DIM) + tl.store(dK_block_ptrs, dK_block, mask=offs_dim[None, :] < DIM) + else: + if EVEN_DIM: + tl.store(dV_block_ptrs, dV_block, mask=offs_kv[:, None] < SEQ_LEN) + tl.store(dK_block_ptrs, dK_block, mask=offs_kv[:, None] < SEQ_LEN) + else: + tl.store( + dV_block_ptrs, + dV_block, + mask=(offs_kv[:, None] < SEQ_LEN) & (offs_dim[None, :] < DIM), + ) + tl.store( + dK_block_ptrs, + dK_block, + mask=(offs_kv[:, None] < SEQ_LEN) & (offs_dim[None, :] < DIM), + ) + + class EvoformerAttention(torch.autograd.Function): + @staticmethod + def forward(ctx, Q, K, V, res_mask, pair_bias, has_pair_bias=True): + """Run the forward pass of the attention mechanism. + + has_pair_bias: set False when pair_bias is all-zeros (MSA column attention). + This eliminates all pair_bias HBM loads in the forward kernel. 
+ """ + # Q, K, V: [Batch, N_seq, N_res, Head, Dim] + # res_mask: [Batch, N_seq, 1, 1, N_res] + # pair_bias: [Batch, 1, Head, N_res, N_res] + + DIM_Q, DIM_K, DIM_V = Q.shape[-1], K.shape[-1], V.shape[-1] + assert DIM_Q == DIM_K and DIM_K == DIM_V + + Q = Q.transpose( + -2, -3 + ).contiguous() # (BATCH_SIZE, N_SEQ, HEAD, SEQ_LEN, DIM) + K = K.transpose( + -2, -3 + ).contiguous() # (BATCH_SIZE, N_SEQ, HEAD, SEQ_LEN, DIM) + V = V.transpose( + -2, -3 + ).contiguous() # (BATCH_SIZE, N_SEQ, HEAD, SEQ_LEN, DIM) + + BATCH_SIZE, N_SEQ, HEAD, SEQ_LEN, DIM = Q.shape + softmax_scale = DIM**-0.5 + BLOCK_DIM = max(triton.next_power_of_2(DIM), 32) + + O = torch.empty_like(Q) + + extra_kern_args = {} + if is_hip(): + waves_per_eu = 3 if DIM <= 64 else 2 + extra_kern_args = { + "waves_per_eu": waves_per_eu, + "allow_flush_denorm": True, + } + + block_size_q = 64 + + grid = lambda args: ( # noqa: E731 + triton.cdiv(SEQ_LEN, args["BLOCK_SIZE_Q"]), + BATCH_SIZE * N_SEQ * HEAD, + 1, + ) + + # M is the logsumexp for the backward pass, one for each query + M = torch.empty( + (BATCH_SIZE, N_SEQ, HEAD, SEQ_LEN), device=Q.device, dtype=torch.float32 + ) + + _attn_fwd[grid]( + Q=Q, + K=K, + V=V, + res_mask=res_mask, + pair_bias=pair_bias, + softmax_scale=softmax_scale, + M=M, + O=O, + stride_Q_batch=Q.stride(0), + stride_Q_msa=Q.stride(1), + stride_Q_head=Q.stride(2), + stride_Q_seq=Q.stride(3), + stride_Q_dim=Q.stride(4), + stride_K_batch=K.stride(0), + stride_K_msa=K.stride(1), + stride_K_head=K.stride(2), + stride_K_seq=K.stride(3), + stride_K_dim=K.stride(4), + stride_V_batch=V.stride(0), + stride_V_msa=V.stride(1), + stride_V_head=V.stride(2), + stride_V_seq=V.stride(3), + stride_V_dim=V.stride(4), + stride_O_batch=O.stride(0), + stride_O_msa=O.stride(1), + stride_O_head=O.stride(2), + stride_O_seq=O.stride(3), + stride_O_dim=O.stride(4), + stride_pair_bias_batch=pair_bias.stride(0), + stride_pair_bias_head=pair_bias.stride(2), + stride_pair_bias_seq1=pair_bias.stride(3), + 
stride_pair_bias_seq2=pair_bias.stride(4), + stride_mask_batch=res_mask.stride(0), + stride_mask_msa=res_mask.stride(1), + stride_mask_seq=res_mask.stride(4), + BATCH_SIZE=BATCH_SIZE, + HEAD=HEAD, + N_SEQ=N_SEQ, + SEQ_LEN=SEQ_LEN, + DIM=DIM, + BLOCK_DIM=BLOCK_DIM, + HAS_PAIR_BIAS=has_pair_bias, + BLOCK_SIZE_Q=block_size_q, + BLOCK_SIZE_KV=16, + num_warps=4, + num_stages=1, + **extra_kern_args, + ) + + ctx.save_for_backward(Q, K, V, res_mask, pair_bias, O, M) + ctx.grid = grid + ctx.softmax_scale = softmax_scale + ctx.DIM = DIM + + O = O.transpose(-2, -3).contiguous() + + return O + + @staticmethod + def backward(ctx, dO): + """Run the backward pass of the attention mechanism.""" + + Q, K, V, res_mask, pair_bias, O, M = ctx.saved_tensors + dO = dO.transpose( + -2, -3 + ).contiguous() # (BATCH_SIZE, N_SEQ, HEAD, SEQ_LEN, DIM) + + assert Q.stride() == K.stride() == V.stride() == O.stride() == dO.stride() + dQ = torch.empty_like(Q) + dK = torch.empty_like(K) + dV = torch.empty_like(V) + + BATCH_SIZE, N_SEQ, HEAD, SEQ_LEN, DIM = dQ.shape + + d_pair_bias = torch.empty( + (BATCH_SIZE, 1, HEAD, SEQ_LEN, SEQ_LEN), + device=pair_bias.device, + dtype=torch.float32, + ).zero_() + + BLOCK_DIM = max(triton.next_power_of_2(DIM), 32) + + D = torch.empty_like(M) # Shape: (BATCH_SIZE, N_SEQ, HEAD, SEQ_LEN) + + preprocess_grid = lambda args: ( # noqa: E731 + triton.cdiv(SEQ_LEN, args["BLOCK_SIZE_Q"]), + BATCH_SIZE * N_SEQ * HEAD, + 1, + ) + _attn_bwd_preprocess[preprocess_grid]( + O=O, + dO=dO, + D=D, + SEQ_LEN=SEQ_LEN, + DIM=DIM, + BLOCK_DIM=BLOCK_DIM, + BLOCK_SIZE_Q=16, + num_warps=4, + num_stages=2, + ) + + bwd_dk_dv_grid = lambda args: ( # noqa: E731 + triton.cdiv(SEQ_LEN, args["BLOCK_SIZE_KV"]), + BATCH_SIZE * N_SEQ * HEAD, + 1, + ) + _attn_bwd_dk_dv[bwd_dk_dv_grid]( + Q=Q, + K=K, + V=V, + res_mask=res_mask, + pair_bias=pair_bias, + softmax_scale=ctx.softmax_scale, + dO=dO, + dQ=dQ, + dK=dK, + dV=dV, + M=M, + D=D, + stride_batch=Q.stride(0), + stride_msa=Q.stride(1), + 
stride_head=Q.stride(2), + stride_seq=Q.stride(3), + stride_pair_bias_batch=pair_bias.stride(0), + stride_pair_bias_head=pair_bias.stride(2), + stride_pair_bias_seq1=pair_bias.stride(3), + stride_pair_bias_seq2=pair_bias.stride(4), + stride_mask_batch=res_mask.stride(0), + stride_mask_msa=res_mask.stride(1), + stride_mask_seq=res_mask.stride(4), + HEAD=HEAD, + N_SEQ=N_SEQ, + SEQ_LEN=SEQ_LEN, + BLOCK_DIM=BLOCK_DIM, + DIM=ctx.DIM, + BLOCK_SIZE_Q=64, + BLOCK_SIZE_KV=64, + num_warps=4, + num_stages=1, + ) + + bwd_dq_grid = lambda args: ( # noqa: E731 + triton.cdiv(SEQ_LEN, args["BLOCK_SIZE_Q"]), + BATCH_SIZE * N_SEQ * HEAD, + 1, + ) + _attn_bwd_dq[bwd_dq_grid]( + Q=Q, + K=K, + V=V, + res_mask=res_mask, + pair_bias=pair_bias, + softmax_scale=ctx.softmax_scale, + dO=dO, + dQ=dQ, + dK=dK, + dV=dV, + d_pair_bias=d_pair_bias, + M=M, + D=D, + stride_batch=Q.stride(0), + stride_msa=Q.stride(1), + stride_head=Q.stride(2), + stride_seq=Q.stride(3), + stride_pair_bias_batch=pair_bias.stride(0), + stride_pair_bias_head=pair_bias.stride(2), + stride_pair_bias_seq1=pair_bias.stride(3), + stride_pair_bias_seq2=pair_bias.stride(4), + stride_mask_batch=res_mask.stride(0), + stride_mask_msa=res_mask.stride(1), + stride_mask_seq=res_mask.stride(4), + stride_d_pair_bias_batch=d_pair_bias.stride(0), + stride_d_pair_bias_head=d_pair_bias.stride(2), + stride_d_pair_bias_seq1=d_pair_bias.stride(3), + stride_d_pair_bias_seq2=d_pair_bias.stride(4), + HEAD=HEAD, + N_SEQ=N_SEQ, + SEQ_LEN=SEQ_LEN, + BLOCK_DIM=BLOCK_DIM, + DIM=ctx.DIM, + BLOCK_SIZE_Q=16, + BLOCK_SIZE_KV=16, + num_warps=4, + num_stages=1, + ) + + dQ = dQ.transpose(-2, -3).contiguous() + dK = dK.transpose(-2, -3).contiguous() + dV = dV.transpose(-2, -3).contiguous() + + return dQ, dK, dV, None, d_pair_bias.to(dO.dtype), None + + TritonEvoformer = EvoformerAttention.apply diff --git a/openfold3/core/model/heads/head_modules.py b/openfold3/core/model/heads/head_modules.py index beff5a44a..46a132ac4 100644 --- 
a/openfold3/core/model/heads/head_modules.py +++ b/openfold3/core/model/heads/head_modules.py @@ -1,4 +1,5 @@ # Copyright 2026 AlQuraishi Laboratory +# Copyright 2026 Advanced Micro Devices, Inc. # Copyright 2021 DeepMind Technologies Limited # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -97,6 +98,7 @@ def forward( chunk_size: int | None = None, use_deepspeed_evo_attention: bool = False, use_cueq_triangle_kernels: bool = False, + use_triton_triangle_kernels: bool = False, use_lma: bool = False, inplace_safe: bool = False, offload_inference: bool = False, @@ -127,6 +129,8 @@ def forward( use_cueq_triangle_kernels: Whether to use cuEq triangle attention kernel. Mutually exclusive with use_lma + use_triton_triangle_kernels: + Whether to use Triton triangle attention kernel. use_lma: Whether to use low-memory attention during inference. Mutually exclusive with use_deepspeed_evo_attention. @@ -203,6 +207,7 @@ def forward( chunk_size=chunk_size, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, inplace_safe=inplace_safe, offload_inference=offload_inference, diff --git a/openfold3/core/model/heads/prediction_heads.py b/openfold3/core/model/heads/prediction_heads.py index 1500eca60..e957bc68c 100644 --- a/openfold3/core/model/heads/prediction_heads.py +++ b/openfold3/core/model/heads/prediction_heads.py @@ -1,4 +1,5 @@ # Copyright 2026 AlQuraishi Laboratory +# Copyright 2026 Advanced Micro Devices, Inc. 
# Copyright 2021 DeepMind Technologies Limited # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -124,6 +125,7 @@ def per_sample_pairformer_emb( chunk_size: int | None = None, use_deepspeed_evo_attention: bool = False, use_cueq_triangle_kernels: bool = False, + use_triton_triangle_kernels: bool = False, use_lma: bool = False, inplace_safe: bool = False, offload_inference: bool = False, @@ -155,6 +157,7 @@ def per_sample_pairformer_emb( chunk_size=chunk_size, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, inplace_safe=inplace_safe, _mask_trans=_mask_trans, @@ -183,6 +186,7 @@ def pairformer_emb( chunk_size: int | None = None, use_deepspeed_evo_attention: bool = False, use_cueq_triangle_kernels: bool = False, + use_triton_triangle_kernels: bool = False, use_lma: bool = False, inplace_safe: bool = False, _mask_trans: bool = True, @@ -209,7 +213,11 @@ def reshape_outputs(x: torch.Tensor, feat_dims: list): # in the DS kernel. To avoid this, chunk tuning is disabled in this case. 
# TODO: cuEq seems to fail comparison unit tests with the same settings, # disable for now and verify behavior - use_kernels = use_deepspeed_evo_attention or use_cueq_triangle_kernels + use_kernels = ( + use_deepspeed_evo_attention + or use_cueq_triangle_kernels + or use_triton_triangle_kernels + ) if use_kernels and si.shape[0] > 1: chunk_size = None @@ -221,6 +229,7 @@ def reshape_outputs(x: torch.Tensor, feat_dims: list): chunk_size=chunk_size, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, inplace_safe=inplace_safe, _mask_trans=_mask_trans, @@ -242,6 +251,7 @@ def forward( chunk_size: int | None = None, use_deepspeed_evo_attention: bool = False, use_cueq_triangle_kernels: bool = False, + use_triton_triangle_kernels: bool = False, use_lma: bool = False, inplace_safe: bool = False, offload_inference: bool = False, @@ -303,6 +313,7 @@ def forward( chunk_size=chunk_size, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, inplace_safe=inplace_safe, offload_inference=offload_inference, @@ -319,6 +330,7 @@ def forward( chunk_size=chunk_size, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, inplace_safe=inplace_safe, _mask_trans=_mask_trans, diff --git a/openfold3/core/model/latent/base_blocks.py b/openfold3/core/model/latent/base_blocks.py index c16e70fff..f902c73d7 100644 --- a/openfold3/core/model/latent/base_blocks.py +++ b/openfold3/core/model/latent/base_blocks.py @@ -1,4 +1,5 @@ # Copyright 2026 AlQuraishi Laboratory +# Copyright 2026 Advanced Micro Devices, Inc. 
# Copyright 2021 DeepMind Technologies Limited # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -206,6 +207,7 @@ def forward( transition_ckpt_chunk_size: int | None = None, use_deepspeed_evo_attention: bool = False, use_cueq_triangle_kernels: bool = False, + use_triton_triangle_kernels: bool = False, use_lma: bool = False, inplace_safe: bool = False, _mask_trans: bool = True, @@ -310,6 +312,7 @@ def tri_mul_out_in( pair_mask: torch.Tensor, inplace_safe: bool, use_cueq_triangle_kernels: bool = False, + use_triton_triangle_kernels: bool = False, ) -> torch.Tensor: """Perform the outgoing and incoming triangular multiplicative updates.""" inplace_safe = inplace_safe and (not use_cueq_triangle_kernels) @@ -325,6 +328,7 @@ def tri_mul_out_in( inplace_safe=inplace_safe, _add_with_inplace=True, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, ) if not inplace_safe: z = z + self.ps_dropout_row_layer(tmu_update) @@ -342,6 +346,7 @@ def tri_mul_out_in( inplace_safe=inplace_safe, _add_with_inplace=True, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, ) if not inplace_safe: z = z + self.ps_dropout_row_layer(tmu_update) @@ -357,8 +362,9 @@ def tri_att_start_end( pair_mask: torch.Tensor, use_deepspeed_evo_attention: bool, use_cueq_triangle_kernels: bool, - use_lma: bool, - inplace_safe: bool, + use_triton_triangle_kernels: bool = False, + use_lma: bool = False, + inplace_safe: bool = False, ) -> torch.Tensor: """Perform the starting and ending triangular attention layers.""" z = add( @@ -370,6 +376,7 @@ def tri_att_start_end( chunk_size=_attn_chunk_size, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, inplace_safe=inplace_safe, ) @@ -392,6 +399,7 @@ def tri_att_start_end( 
chunk_size=_attn_chunk_size, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, inplace_safe=inplace_safe, ) @@ -412,6 +420,7 @@ def forward( chunk_size: int | None = None, use_deepspeed_evo_attention: bool = False, use_cueq_triangle_kernels: bool = False, + use_triton_triangle_kernels: bool = False, use_lma: bool = False, inplace_safe: bool = False, _mask_trans: bool = True, @@ -455,6 +464,7 @@ def forward( pair_mask=pair_mask, inplace_safe=inplace_safe, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, ) z = self.tri_att_start_end( @@ -463,6 +473,7 @@ def forward( pair_mask=pair_mask, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, inplace_safe=inplace_safe, ) diff --git a/openfold3/core/model/latent/base_stacks.py b/openfold3/core/model/latent/base_stacks.py index f0463e65f..87a37f37c 100644 --- a/openfold3/core/model/latent/base_stacks.py +++ b/openfold3/core/model/latent/base_stacks.py @@ -1,4 +1,5 @@ # Copyright 2026 AlQuraishi Laboratory +# Copyright 2026 Advanced Micro Devices, Inc. 
# Copyright 2025 NVIDIA Corporation # Copyright 2021 DeepMind Technologies Limited # @@ -84,6 +85,7 @@ def _prep_blocks( transition_ckpt_chunk_size: int | None, use_deepspeed_evo_attention: bool, use_cueq_triangle_kernels: bool, + use_triton_triangle_kernels: bool, use_lma: bool, msa_mask: torch.Tensor | None, pair_mask: torch.Tensor | None, @@ -106,6 +108,7 @@ def _prep_blocks( transition_ckpt_chunk_size=transition_ckpt_chunk_size, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, inplace_safe=inplace_safe, _mask_trans=_mask_trans, @@ -181,6 +184,7 @@ def forward_offload( transition_ckpt_chunk_size: int | None = None, use_deepspeed_evo_attention: bool = False, use_cueq_triangle_kernels: bool = False, + use_triton_triangle_kernels: bool = False, use_lma: bool = False, _mask_trans: bool = True, ): @@ -195,6 +199,7 @@ def forward_offload( transition_ckpt_chunk_size=transition_ckpt_chunk_size, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, msa_mask=msa_mask, pair_mask=pair_mask, @@ -227,6 +232,7 @@ def forward( transition_ckpt_chunk_size: int | None = None, use_deepspeed_evo_attention: bool = False, use_cueq_triangle_kernels: bool = False, + use_triton_triangle_kernels: bool = False, use_lma: bool = False, inplace_safe: bool = False, _mask_trans: bool = True, @@ -269,6 +275,7 @@ def forward( transition_ckpt_chunk_size=transition_ckpt_chunk_size, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, msa_mask=msa_mask, pair_mask=pair_mask, diff --git a/openfold3/core/model/latent/evoformer.py b/openfold3/core/model/latent/evoformer.py index 6fb6e9ece..e9b3278ee 100644 
--- a/openfold3/core/model/latent/evoformer.py +++ b/openfold3/core/model/latent/evoformer.py @@ -1,4 +1,5 @@ # Copyright 2026 AlQuraishi Laboratory +# Copyright 2026 Advanced Micro Devices, Inc. # Copyright 2021 DeepMind Technologies Limited # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -93,6 +94,7 @@ def forward( transition_ckpt_chunk_size: int | None = None, use_deepspeed_evo_attention: bool = False, use_cueq_triangle_kernels: bool = False, + use_triton_triangle_kernels: bool = False, use_lma: bool = False, inplace_safe: bool = False, _mask_trans: bool = True, @@ -134,6 +136,7 @@ def forward( chunk_size=_attn_chunk_size, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, ) ), @@ -158,6 +161,7 @@ def forward( chunk_size=chunk_size, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, ), inplace=inplace_safe, @@ -208,6 +212,7 @@ def forward( chunk_size=chunk_size, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, inplace_safe=inplace_safe, _mask_trans=_mask_trans, diff --git a/openfold3/core/model/latent/msa_module.py b/openfold3/core/model/latent/msa_module.py index af8a79870..ac216687e 100644 --- a/openfold3/core/model/latent/msa_module.py +++ b/openfold3/core/model/latent/msa_module.py @@ -1,4 +1,5 @@ # Copyright 2026 AlQuraishi Laboratory +# Copyright 2026 Advanced Micro Devices, Inc. 
# Copyright 2021 DeepMind Technologies Limited # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -148,6 +149,7 @@ def forward( transition_ckpt_chunk_size: int | None = None, use_deepspeed_evo_attention: bool = False, use_cueq_triangle_kernels: bool = False, + use_triton_triangle_kernels: bool = False, use_lma: bool = False, inplace_safe: bool = False, _mask_trans: bool = True, @@ -246,6 +248,7 @@ def forward( chunk_size=chunk_size, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, inplace_safe=inplace_safe, _mask_trans=_mask_trans, diff --git a/openfold3/core/model/latent/pairformer.py b/openfold3/core/model/latent/pairformer.py index c2e5104ae..6f75923db 100644 --- a/openfold3/core/model/latent/pairformer.py +++ b/openfold3/core/model/latent/pairformer.py @@ -1,4 +1,5 @@ # Copyright 2026 AlQuraishi Laboratory +# Copyright 2026 Advanced Micro Devices, Inc. # Copyright 2025 NVIDIA Corporation # Copyright 2021 DeepMind Technologies Limited # @@ -130,6 +131,7 @@ def forward( chunk_size: int | None = None, use_deepspeed_evo_attention: bool = False, use_cueq_triangle_kernels: bool = False, + use_triton_triangle_kernels: bool = False, use_lma: bool = False, inplace_safe: bool = False, _mask_trans: bool = True, @@ -156,6 +158,9 @@ def forward( update kernel and attention kernel. When both this and use_deepspeed_evo_attention are True, the cuEquivariance kernel is only used for triangle attention + use_triton_triangle_kernels: + Whether to use Triton triangle attention kernel. + Mutually exclusive with use_deepspeed_evo_attention. use_lma: Whether to use low-memory attention during inference. Mutually exclusive with use_deepspeed_evo_attention. 
@@ -180,6 +185,7 @@ def forward( chunk_size=chunk_size, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, inplace_safe=inplace_safe, _mask_trans=_mask_trans, @@ -195,6 +201,7 @@ def forward( mask=single_mask, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, ), inplace=inplace_safe, @@ -325,6 +332,7 @@ def _prep_blocks( chunk_size: int | None, use_deepspeed_evo_attention: bool, use_cueq_triangle_kernels: bool, + use_triton_triangle_kernels: bool, use_lma: bool, inplace_safe: bool, _mask_trans: bool, @@ -345,6 +353,7 @@ def _prep_blocks( chunk_size=chunk_size, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, inplace_safe=inplace_safe, _mask_trans=_mask_trans, @@ -404,6 +413,7 @@ def forward( chunk_size: int | None = None, use_deepspeed_evo_attention: bool = False, use_cueq_triangle_kernels: bool = False, + use_triton_triangle_kernels: bool = False, use_lma: bool = False, inplace_safe: bool = False, _mask_trans: bool = True, @@ -424,6 +434,9 @@ def forward( use_deepspeed_evo_attention: Whether to use DeepSpeed memory efficient kernel. Mutually exclusive with use_lma. + use_triton_triangle_kernels: + Whether to use Triton triangle attention kernel. + Mutually exclusive with use_deepspeed_evo_attention. use_lma: Whether to use low-memory attention during inference. Mutually exclusive with use_deepspeed_evo_attention. 
@@ -445,6 +458,7 @@ def forward( chunk_size=chunk_size, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, inplace_safe=inplace_safe, _mask_trans=_mask_trans, diff --git a/openfold3/core/model/latent/template_module.py b/openfold3/core/model/latent/template_module.py index 89d291b42..820edf1a2 100644 --- a/openfold3/core/model/latent/template_module.py +++ b/openfold3/core/model/latent/template_module.py @@ -1,4 +1,5 @@ # Copyright 2026 AlQuraishi Laboratory +# Copyright 2026 Advanced Micro Devices, Inc. # Copyright 2025 NVIDIA Corporation # Copyright 2021 DeepMind Technologies Limited # @@ -116,10 +117,11 @@ def _forward_single_template( chunk_size: int | None, use_deepspeed_evo_attention: bool, use_cueq_triangle_kernels: bool, - use_lma: bool, - inplace_safe: bool, - _mask_trans: bool, - _attn_chunk_size: int | None, + use_triton_triangle_kernels: bool = False, + use_lma: bool = False, + inplace_safe: bool = False, + _mask_trans: bool = True, + _attn_chunk_size: int | None = None, ): """ Helper function to process exactly one template slice. 
@@ -128,11 +130,17 @@ def _forward_single_template( # t: [1, N, N, C] if self.tri_mul_first: t = self.tri_att_start_end( - z=self.tri_mul_out_in(z=t, pair_mask=mask, inplace_safe=inplace_safe), + z=self.tri_mul_out_in( + z=t, + pair_mask=mask, + inplace_safe=inplace_safe, + use_triton_triangle_kernels=use_triton_triangle_kernels, + ), _attn_chunk_size=_attn_chunk_size, pair_mask=mask, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, inplace_safe=inplace_safe, ) @@ -144,11 +152,13 @@ def _forward_single_template( pair_mask=mask, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, inplace_safe=inplace_safe, ), pair_mask=mask, inplace_safe=inplace_safe, + use_triton_triangle_kernels=use_triton_triangle_kernels, ) t = add( @@ -170,6 +180,7 @@ def forward( chunk_size: int | None = None, use_deepspeed_evo_attention: bool = False, use_cueq_triangle_kernels: bool = False, + use_triton_triangle_kernels: bool = False, use_lma: bool = False, inplace_safe: bool = False, _mask_trans: bool = True, @@ -186,6 +197,8 @@ def forward( use_deepspeed_evo_attention: Whether to use DeepSpeed memory efficient kernel. Mutually exclusive with use_lma. + use_triton_triangle_kernels: + Whether to use Triton triangle attention kernel. use_lma: Whether to use low-memory attention during inference. Mutually exclusive with use_deepspeed_evo_attention. 
@@ -212,6 +225,7 @@ def forward( chunk_size=chunk_size, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, inplace_safe=inplace_safe, _mask_trans=_mask_trans, @@ -350,6 +364,7 @@ def _prep_blocks( chunk_size: int | None = None, use_deepspeed_evo_attention: bool = False, use_cueq_triangle_kernels: bool = False, + use_triton_triangle_kernels: bool = False, use_lma: bool = False, inplace_safe: bool = False, _mask_trans: bool = True, @@ -361,6 +376,7 @@ def _prep_blocks( chunk_size=chunk_size, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, inplace_safe=inplace_safe, _mask_trans=_mask_trans, @@ -404,6 +420,7 @@ def forward( chunk_size: int | None = None, use_deepspeed_evo_attention: bool = False, use_cueq_triangle_kernels: bool = False, + use_triton_triangle_kernels: bool = False, use_lma: bool = False, inplace_safe: bool = False, _mask_trans: bool = True, @@ -419,6 +436,8 @@ def forward( use_deepspeed_evo_attention: Whether to use DeepSpeed memory efficient kernel. Mutually exclusive with use_lma. + use_triton_triangle_kernels: + Whether to use Triton triangle attention kernel. use_lma: Whether to use low-memory attention during inference. Mutually exclusive with use_deepspeed_evo_attention. 
@@ -441,6 +460,7 @@ def forward( chunk_size=chunk_size, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, inplace_safe=inplace_safe, _mask_trans=_mask_trans, @@ -491,6 +511,7 @@ def forward( _mask_trans: bool = True, use_deepspeed_evo_attention: bool = False, use_cueq_triangle_kernels: bool = False, + use_triton_triangle_kernels: bool = False, use_lma: bool = False, inplace_safe: bool = False, ) -> torch.Tensor: @@ -534,6 +555,7 @@ def forward( chunk_size=chunk_size, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, inplace_safe=inplace_safe, _mask_trans=_mask_trans, diff --git a/openfold3/core/model/layers/attention_pair_bias.py b/openfold3/core/model/layers/attention_pair_bias.py index 975bd76c6..59712ceaf 100644 --- a/openfold3/core/model/layers/attention_pair_bias.py +++ b/openfold3/core/model/layers/attention_pair_bias.py @@ -1,4 +1,5 @@ # Copyright 2026 AlQuraishi Laboratory +# Copyright 2026 Advanced Micro Devices, Inc. 
# Copyright 2021 DeepMind Technologies Limited # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -172,6 +173,7 @@ def forward( mask: torch.Tensor | None = None, use_deepspeed_evo_attention: bool = False, use_cueq_triangle_kernels: bool = False, + use_triton_triangle_kernels: bool = False, use_lma: bool = False, use_high_precision_attention: bool = False, ) -> torch.Tensor: @@ -188,6 +190,8 @@ def forward( [*, N] Mask for token or atom-level embedding use_deepspeed_evo_attention: Whether to use DeepSpeed Evo Attention kernel + use_triton_triangle_kernels: + Whether to use Triton triangle attention kernel use_lma: Whether to use LMA use_high_precision_attention: @@ -204,7 +208,9 @@ def forward( # Current reshape function only expects missing batch dim batch_dims = a.shape[:-2] reshape_for_ds_kernel = ( - use_deepspeed_evo_attention or use_cueq_triangle_kernels + use_deepspeed_evo_attention + or use_cueq_triangle_kernels + or use_triton_triangle_kernels ) and len(batch_dims) == 1 if reshape_for_ds_kernel: a = a.unsqueeze(1) @@ -216,6 +222,7 @@ def forward( biases=biases, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, use_high_precision=use_high_precision_attention, ) @@ -375,6 +382,7 @@ def forward( mask: torch.Tensor | None = None, use_high_precision_attention: bool = False, use_cueq_triangle_kernels: bool = False, + use_triton_triangle_kernels: bool = False, ) -> torch.Tensor: """ Args: @@ -419,6 +427,7 @@ def forward( biases=biases, use_high_precision=use_high_precision_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, ) # Convert back to unpadded and flattened atom representation diff --git a/openfold3/core/model/layers/diffusion_transformer.py b/openfold3/core/model/layers/diffusion_transformer.py index c6216e6eb..6577df4e1 100644 --- 
a/openfold3/core/model/layers/diffusion_transformer.py +++ b/openfold3/core/model/layers/diffusion_transformer.py @@ -1,4 +1,5 @@ # Copyright 2026 AlQuraishi Laboratory +# Copyright 2026 Advanced Micro Devices, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -122,6 +123,7 @@ def forward( mask: torch.Tensor | None = None, use_deepspeed_evo_attention: bool = False, use_cueq_triangle_kernels: bool = False, + use_triton_triangle_kernels: bool = False, use_lma: bool = False, use_high_precision_attention: bool = False, _mask_trans: bool = True, @@ -138,6 +140,8 @@ def forward( [*, N] Mask for token-level embedding use_deepspeed_evo_attention: Whether to use DeepSpeed Evo Attention kernel + use_triton_triangle_kernels: + Whether to use Triton triangle attention kernel use_lma: Whether to use LMA use_high_precision_attention: @@ -155,6 +159,7 @@ def forward( mask=mask, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, use_high_precision_attention=use_high_precision_attention, ) @@ -273,6 +278,7 @@ def forward( mask: torch.Tensor | None = None, use_deepspeed_evo_attention: bool = False, use_cueq_triangle_kernels: bool = False, + use_triton_triangle_kernels: bool = False, use_lma: bool = False, use_high_precision_attention: bool = False, _mask_trans: bool = True, @@ -291,6 +297,8 @@ def forward( Whether to use DeepSpeed Evo Attention kernel use_cueq_triangle_kernels: Whether to use cuEq triangle kernels + use_triton_triangle_kernels: + Whether to use Triton triangle attention kernel use_lma: Whether to use LMA use_high_precision_attention: @@ -310,6 +318,7 @@ def forward( mask=mask, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, 
use_lma=use_lma, use_high_precision_attention=use_high_precision_attention, _mask_trans=_mask_trans, diff --git a/openfold3/core/model/layers/msa.py b/openfold3/core/model/layers/msa.py index 3acd47c29..bf476d71d 100644 --- a/openfold3/core/model/layers/msa.py +++ b/openfold3/core/model/layers/msa.py @@ -1,4 +1,5 @@ # Copyright 2026 AlQuraishi Laboratory +# Copyright 2026 Advanced Micro Devices, Inc. # Copyright 2021 DeepMind Technologies Limited # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -109,6 +110,7 @@ def _chunk( chunk_size: int, use_deepspeed_evo_attention: bool, use_cueq_triangle_kernels: bool, + use_triton_triangle_kernels: bool, use_lma: bool, ) -> torch.Tensor: def fn(m, biases): @@ -119,6 +121,7 @@ def fn(m, biases): biases=biases, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, ) @@ -231,6 +234,7 @@ def forward( chunk_size: int | None = None, use_deepspeed_evo_attention: bool = False, use_cueq_triangle_kernels: bool = False, + use_triton_triangle_kernels: bool = False, use_lma: bool = False, inplace_safe: bool = False, _chunk_logits: int | None = None, @@ -274,6 +278,7 @@ def forward( chunk_size, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, ) else: @@ -284,6 +289,7 @@ def forward( biases=biases, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, ) @@ -383,6 +389,7 @@ def forward( chunk_size: int | None = None, use_deepspeed_evo_attention: bool = False, use_cueq_triangle_kernels: bool = False, + use_triton_triangle_kernels: bool = False, use_lma: bool = False, ) -> torch.Tensor: """ @@ -407,6 +414,7 @@ def forward( 
chunk_size=chunk_size, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, ) diff --git a/openfold3/core/model/layers/triangular_attention.py b/openfold3/core/model/layers/triangular_attention.py index 2d17a0a52..82631de40 100644 --- a/openfold3/core/model/layers/triangular_attention.py +++ b/openfold3/core/model/layers/triangular_attention.py @@ -1,4 +1,5 @@ # Copyright 2026 AlQuraishi Laboratory +# Copyright 2026 Advanced Micro Devices, Inc. # Copyright 2021 DeepMind Technologies Limited # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -76,6 +77,7 @@ def _chunk( chunk_size: int, use_deepspeed_evo_attention: bool = False, use_cueq_triangle_kernels: bool = False, + use_triton_triangle_kernels: bool = False, use_lma: bool = False, inplace_safe: bool = False, ) -> torch.Tensor: @@ -91,6 +93,7 @@ def _chunk( self.mha, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, ), mha_inputs, @@ -106,6 +109,7 @@ def forward( chunk_size: int | None = None, use_deepspeed_evo_attention: bool = False, use_cueq_triangle_kernels: bool = False, + use_triton_triangle_kernels: bool = False, use_lma: bool = False, inplace_safe: bool = False, ) -> torch.Tensor: @@ -149,6 +153,7 @@ def forward( use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_lma=use_lma, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, inplace_safe=inplace_safe, ) else: @@ -159,6 +164,7 @@ def forward( use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_lma=use_lma, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, ) if not self.starting: diff --git 
a/openfold3/core/model/layers/triangular_multiplicative_update.py b/openfold3/core/model/layers/triangular_multiplicative_update.py index 9f779b6f8..2f33b6d21 100644 --- a/openfold3/core/model/layers/triangular_multiplicative_update.py +++ b/openfold3/core/model/layers/triangular_multiplicative_update.py @@ -1,4 +1,5 @@ # Copyright 2026 AlQuraishi Laboratory +# Copyright 2026 Advanced Micro Devices, Inc. # Copyright 2021 DeepMind Technologies Limited # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -33,9 +34,579 @@ if is_cuequivariance_available(): from cuequivariance_torch import triangle_multiplicative_update +try: + import triton + import triton.language as tl + + TRITON_AVAILABLE = True +except ImportError: + TRITON_AVAILABLE = False warnings.filterwarnings("once") +if TRITON_AVAILABLE: + + @triton.jit + def sigmoid_mul_kernel( + x_ptr, + gate_ptr, + output_ptr, + N, + BLOCK_SIZE: tl.constexpr, + ): + """ + Fused kernel: output = x * sigmoid(gate) + """ + pid = tl.program_id(0) + + block_start = pid * BLOCK_SIZE + if block_start >= N: + return + + offs = block_start + tl.arange(0, BLOCK_SIZE) + mask = (offs >= 0) & (offs < N) + + safe_offs = offs.to(tl.int64) + x_vals = tl.load(x_ptr + safe_offs, mask=mask, other=0.0) + gate_vals = tl.load(gate_ptr + safe_offs, mask=mask, other=0.0) + + sigmoid_gate = 1.0 / (1.0 + tl.exp(-gate_vals)) + result = x_vals * sigmoid_gate + + tl.store(output_ptr + safe_offs, result, mask=mask) + + @triton.jit + def layernorm_kernel( + x_ptr, + output_ptr, + weight_ptr, + bias_ptr, + M, + N, + eps, + BLOCK_SIZE: tl.constexpr, + ROWS_PER_PROGRAM: tl.constexpr, + ): + """ + Vectorized LayerNorm kernel with multi-row processing and safe bounds checking. + Each program processes ROWS_PER_PROGRAM rows to amortize launch overhead. + Uses vectorized tl.sum reductions instead of scalar Welford loop. + Normalizes over the last dimension (N). 
+ """ + row_start = tl.program_id(0) * ROWS_PER_PROGRAM + + # Early exit if out of bounds + if row_start >= M: + return + + for row in range(ROWS_PER_PROGRAM): + row_idx = row_start + row + if row_idx < M: + row_offset = row_idx.to(tl.int64) * N + x_row_ptr = x_ptr + row_offset + out_row_ptr = output_ptr + row_offset + + # Pass 1: vectorized mean and variance computation + sum_val = 0.0 + sum_sq = 0.0 + for block_start in range(0, N, BLOCK_SIZE): + offs = block_start + tl.arange(0, BLOCK_SIZE) + mask = (offs >= 0) & (offs < N) + safe_offs = offs.to(tl.int64) + x_block = tl.load(x_row_ptr + safe_offs, mask=mask, other=0.0).to( + tl.float32 + ) + sum_val += tl.sum(x_block, axis=0) + sum_sq += tl.sum(x_block * x_block, axis=0) + + mean = sum_val / N + var = sum_sq / N - mean * mean + rstd = 1.0 / tl.sqrt(tl.maximum(var, eps)) + + # Pass 2: normalize and apply affine transform + for block_start in range(0, N, BLOCK_SIZE): + offs = block_start + tl.arange(0, BLOCK_SIZE) + mask = (offs >= 0) & (offs < N) + + safe_offs = offs.to(tl.int64) + x_vals = tl.load(x_row_ptr + safe_offs, mask=mask, other=0.0).to( + tl.float32 + ) + w_vals = tl.load(weight_ptr + safe_offs, mask=mask, other=1.0).to( + tl.float32 + ) + b_vals = tl.load(bias_ptr + safe_offs, mask=mask, other=0.0).to( + tl.float32 + ) + + out = (x_vals - mean) * rstd * w_vals + b_vals + tl.store(out_row_ptr + safe_offs, out, mask=mask) + + @triton.jit + def linear_kernel( + x_ptr, + w_ptr, + bias_ptr, + output_ptr, + M, + K, + N, + stride_xm, + stride_xk, + stride_wn, + stride_wk, + stride_om, + stride_on, + HAS_BIAS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ): + """ + Linear layer: output = x @ weight.T + bias + x: [M, K], weight: [N, K] (PyTorch layout), bias: [N], output: [M, N] + """ + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + m_mask = rm < M + n_mask = rn 
< N + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + for k_start in range(0, K, BLOCK_K): + rk = k_start + tl.arange(0, BLOCK_K) + k_mask = rk < K + + x_offsets = ( + rm[:, None].to(tl.int64) * stride_xm + + rk[None, :].to(tl.int64) * stride_xk + ) + + w_offsets = ( + rn[None, :].to(tl.int64) * stride_wn + + rk[:, None].to(tl.int64) * stride_wk + ) + + x = tl.load( + x_ptr + x_offsets, + mask=m_mask[:, None] & k_mask[None, :], + other=0.0, + ) + + w = tl.load( + w_ptr + w_offsets, + mask=k_mask[:, None] & n_mask[None, :], + other=0.0, + ) + + acc = tl.dot(x, w, acc) + + if HAS_BIAS: + bias_offsets = rn.to(tl.int64) + bias = tl.load(bias_ptr + bias_offsets, mask=n_mask, other=0.0) + acc += bias[None, :] + + out_offsets = ( + rm[:, None].to(tl.int64) * stride_om + rn[None, :].to(tl.int64) * stride_on + ) + + out_mask = m_mask[:, None] & n_mask[None, :] + tl.store( + output_ptr + out_offsets, + acc, + mask=out_mask, + ) + + @triton.jit + def linear_fused_kernel( + x_ptr, + w_ptr, + bias_ptr, + other_ptr, + mask_ptr, + add_tensor_ptr, + output_ptr, + M, + K, + N, + stride_xm, + stride_xk, + stride_wn, + stride_wk, + stride_om, + stride_on, + stride_other_m, + stride_other_n, + stride_mask_m, + stride_mask_n, + stride_add_m, + stride_add_n, + HAS_BIAS: tl.constexpr, + APPLY_SIGMOID: tl.constexpr, + APPLY_MUL: tl.constexpr, + HAS_MASK: tl.constexpr, + HAS_ADD: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ): + """ + Fused linear layer with optional sigmoid, elementwise multiply, mask, and add + Supports combinations: linear [+ sigmoid] [* other] [* mask] [+ add_tensor] + """ + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + m_mask = rm < M + n_mask = rn < N + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + for k_start in range(0, K, BLOCK_K): + rk = k_start + tl.arange(0, BLOCK_K) + k_mask = rk < K + + 
x_offsets = ( + rm[:, None].to(tl.int64) * stride_xm + + rk[None, :].to(tl.int64) * stride_xk + ) + + w_offsets = ( + rn[None, :].to(tl.int64) * stride_wn + + rk[:, None].to(tl.int64) * stride_wk + ) + + x = tl.load( + x_ptr + x_offsets, + mask=m_mask[:, None] & k_mask[None, :], + other=0.0, + ) + + w = tl.load( + w_ptr + w_offsets, + mask=k_mask[:, None] & n_mask[None, :], + other=0.0, + ) + + acc = tl.dot(x, w, acc) + + if HAS_BIAS: + bias_offsets = rn.to(tl.int64) + bias = tl.load(bias_ptr + bias_offsets, mask=n_mask, other=0.0) + acc += bias[None, :] + + if APPLY_SIGMOID: + # Clamp to avoid exp overflow + acc = tl.where(acc > 20.0, 20.0, acc) + acc = tl.where(acc < -20.0, -20.0, acc) + acc = 1.0 / (1.0 + tl.exp(-acc)) + + if APPLY_MUL: + other_offsets = ( + rm[:, None].to(tl.int64) * stride_other_m + + rn[None, :].to(tl.int64) * stride_other_n + ) + other_vals = tl.load( + other_ptr + other_offsets, + mask=m_mask[:, None] & n_mask[None, :], + other=0.0, + ) + acc = acc * other_vals + + if HAS_MASK: + mask_offsets = ( + rm[:, None].to(tl.int64) * stride_mask_m + + rn[None, :].to(tl.int64) * stride_mask_n + ) + mask_vals = tl.load( + mask_ptr + mask_offsets, + mask=m_mask[:, None] & n_mask[None, :], + other=0.0, + ) + acc = acc * mask_vals + + if HAS_ADD: + add_offsets = ( + rm[:, None].to(tl.int64) * stride_add_m + + rn[None, :].to(tl.int64) * stride_add_n + ) + add_vals = tl.load( + add_tensor_ptr + add_offsets, + mask=m_mask[:, None] & n_mask[None, :], + other=0.0, + ) + acc = acc + add_vals + + out_offsets = ( + rm[:, None].to(tl.int64) * stride_om + rn[None, :].to(tl.int64) * stride_on + ) + + out_mask = m_mask[:, None] & n_mask[None, :] + tl.store( + output_ptr + out_offsets, + acc, + mask=out_mask, + ) + + +def triton_sigmoid_mul(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + """ + Fused sigmoid + elementwise multiply: x * sigmoid(gate) + Supports broadcasting like PyTorch's native operations. 
+ + Args: + x: Input tensor + gate: Gate tensor (broadcastable with x) + + Returns: + x * sigmoid(gate) + """ + broadcasted_shape = torch.broadcast_shapes(x.shape, gate.shape) + x_expanded = x.expand(broadcasted_shape) + gate_expanded = gate.expand(broadcasted_shape) + + x_flat = x_expanded.contiguous().reshape(-1) + gate_flat = gate_expanded.contiguous().reshape(-1) + N = x_flat.numel() + + output_flat = torch.empty_like(x_flat) + + BLOCK_SIZE = min(1024, triton.next_power_of_2(N)) + grid = (triton.cdiv(N, BLOCK_SIZE),) + + sigmoid_mul_kernel[grid](x_flat, gate_flat, output_flat, N, BLOCK_SIZE=BLOCK_SIZE) + + return output_flat.reshape(broadcasted_shape) + + +def triton_layernorm( + x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float = 1e-5 +) -> torch.Tensor: + """ + Triton LayerNorm with vectorized reductions and multi-row processing. + + Args: + x: Input tensor [..., N] + weight: Scale parameters [N] + bias: Shift parameters [N] + eps: Small constant for numerical stability + + Returns: + Normalized tensor with same shape as x + """ + input_shape = x.shape + N = input_shape[-1] + M = x.numel() // N + + x_2d = x.reshape(M, N).contiguous() + output_2d = torch.empty_like(x_2d) + + BLOCK_SIZE = min(1024, triton.next_power_of_2(N)) + ROWS_PER_PROGRAM = 8 + grid = (triton.cdiv(M, ROWS_PER_PROGRAM),) + + layernorm_kernel[grid]( + x_2d, + output_2d, + weight, + bias, + M, + N, + eps, + BLOCK_SIZE=BLOCK_SIZE, + ROWS_PER_PROGRAM=ROWS_PER_PROGRAM, + ) + + return output_2d.reshape(input_shape) + + +def triton_linear( + x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None +) -> torch.Tensor: + """ + Triton Linear: output = x @ weight.T + bias + + Args: + x: Input tensor [..., K] + weight: Weight tensor [N, K] (same layout as nn.Linear.weight) + bias: Optional bias tensor [N] + + Returns: + Output tensor [..., N] + """ + input_shape = x.shape + K = input_shape[-1] + M = x.numel() // K + N = weight.shape[0] + + x_2d = x.reshape(M, 
K).contiguous() + output_2d = torch.empty((M, N), dtype=x.dtype, device=x.device) + + BLOCK_M = min(128, triton.next_power_of_2(M)) + BLOCK_N = min(64, triton.next_power_of_2(N)) + BLOCK_K = min(64, triton.next_power_of_2(K)) + + grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N)) + + bias_ptr = bias if bias is not None else weight # guarded by HAS_BIAS constexpr + + linear_kernel[grid]( + x_2d, + weight, + bias_ptr, + output_2d, + M, + K, + N, + x_2d.stride(0), + x_2d.stride(1), + weight.stride(0), + weight.stride(1), + output_2d.stride(0), + output_2d.stride(1), + HAS_BIAS=(bias is not None), + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, + ) + + output_shape = input_shape[:-1] + (N,) + return output_2d.reshape(output_shape) + + +def triton_linear_fused( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None = None, + other: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + add_tensor: torch.Tensor | None = None, + apply_sigmoid: bool = False, +) -> torch.Tensor: + """ + Fused linear layer with optional sigmoid, elementwise multiply, mask, and add. 
+ + Computes: linear(x) [* sigmoid] [* other] [* mask] [+ add_tensor] + + Args: + x: Input tensor [..., K] + weight: Weight tensor [N, K] (same layout as nn.Linear.weight) + bias: Optional bias tensor [N] + other: Optional tensor to multiply with (same shape as output) + mask: Optional mask tensor (broadcastable with output) + add_tensor: Optional tensor to add to result (same shape as output) + apply_sigmoid: Whether to apply sigmoid to linear result + + Returns: + Output tensor [..., N] + """ + input_shape = x.shape + K = input_shape[-1] + M = x.numel() // K + N = weight.shape[0] + + x_2d = x.reshape(M, K).contiguous() + output_2d = torch.empty((M, N), dtype=x.dtype, device=x.device) + + BLOCK_M = min(128, triton.next_power_of_2(M)) + BLOCK_N = min(64, triton.next_power_of_2(N)) + BLOCK_K = min(64, triton.next_power_of_2(K)) + + grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N)) + + if other is not None: + # other should match the output shape [..., N], so reshape accordingly + other_M = other.numel() // N + other_2d = other.reshape(other_M, N).contiguous() + if other_M != M: + raise ValueError( + f"Shape mismatch: tensor expects {other_M} rows but output has {M}" + ) + other_ptr = other_2d + other_stride_m, other_stride_n = other_2d.stride(0), other_2d.stride(1) + else: + other_ptr = x_2d # guarded by APPLY_MUL constexpr + other_stride_m, other_stride_n = 0, 0 + + if mask is not None: + # mask should broadcast to output shape [..., N] + mask_M = mask.numel() // (mask.shape[-1] if mask.numel() > 0 else 1) + mask_features = mask.shape[-1] if mask.ndim > 0 else 1 + if mask_features == 1: + # Broadcast mask to match output + mask_2d = mask.reshape(mask_M, 1).expand(M, N).contiguous() + else: + mask_2d = mask.reshape(mask_M, mask_features).contiguous() + if mask_M != M or mask_features != N: + raise ValueError( + f"Mask shape mismatch: got {mask_M}x{mask_features}, " + f"expected {M}x{N}" + ) + mask_ptr = mask_2d + mask_stride_m, mask_stride_n = 
mask_2d.stride(0), mask_2d.stride(1) + else: + mask_ptr = x_2d # guarded by HAS_MASK constexpr + mask_stride_m, mask_stride_n = 0, 0 + + if add_tensor is not None: + # add_tensor should match the output shape [..., N] + add_M = add_tensor.numel() // N + add_tensor_2d = add_tensor.reshape(add_M, N).contiguous() + if add_M != M: + raise ValueError( + f"Shape mismatch: add_tensor expects {add_M} rows but output has {M}" + ) + add_tensor_ptr = add_tensor_2d + add_stride_m, add_stride_n = add_tensor_2d.stride(0), add_tensor_2d.stride(1) + else: + add_tensor_ptr = x_2d # guarded by HAS_ADD constexpr + add_stride_m, add_stride_n = 0, 0 + + bias_ptr = bias if bias is not None else x_2d # guarded by HAS_BIAS constexpr + + linear_fused_kernel[grid]( + x_2d, + weight, + bias_ptr, + other_ptr, + mask_ptr, + add_tensor_ptr, + output_2d, + M, + K, + N, + x_2d.stride(0), + x_2d.stride(1), + weight.stride(0), + weight.stride(1), + output_2d.stride(0), + output_2d.stride(1), + other_stride_m, + other_stride_n, + mask_stride_m, + mask_stride_n, + add_stride_m, + add_stride_n, + HAS_BIAS=(bias is not None), + APPLY_SIGMOID=apply_sigmoid, + APPLY_MUL=(other is not None), + HAS_MASK=(mask is not None), + HAS_ADD=(add_tensor is not None), + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, + ) + + output_shape = input_shape[:-1] + (N,) + return output_2d.reshape(output_shape) + class BaseTriangleMultiplicativeUpdate(nn.Module, ABC): """ @@ -156,6 +727,7 @@ def _inference_forward( mask: torch.Tensor | None = None, inplace_chunk_size: int | None = None, with_add: bool = True, + use_triton_triangle_kernels: bool = False, ): """ Args: @@ -221,7 +793,7 @@ def _inference_forward( mask = mask.unsqueeze(-1) - def compute_projection_helper(pair, mask, a=True): + def compute_projection_helper(pair, mask, a=True, use_triton=False): if a: linear_g = self.linear_a_g linear_p = self.linear_a_p @@ -229,18 +801,33 @@ def compute_projection_helper(pair, mask, a=True): linear_g = self.linear_b_g 
linear_p = self.linear_b_p - pair = self.layer_norm_in(pair) - p = linear_g(pair) - p.sigmoid_() - p *= linear_p(pair) - p *= mask + if use_triton: + pair = triton_layernorm( + pair, + self.layer_norm_in.weight, + self.layer_norm_in.bias, + self.layer_norm_in.eps, + ) + # Fused: sigmoid(linear_g(pair)) * linear_p(pair) * mask + p_g = triton_linear_fused( + pair, linear_g.weight, linear_g.bias, apply_sigmoid=True + ) + p = triton_linear_fused( + pair, linear_p.weight, linear_p.bias, other=p_g, mask=mask + ) + else: + pair = self.layer_norm_in(pair) + p = linear_g(pair) + p.sigmoid_() + p *= linear_p(pair) + p *= mask p = permute_final_dims(p, (2, 0, 1)) return p - def compute_projection(pair, mask, a=True, chunked=True): + def compute_projection(pair, mask, a=True, chunked=True, use_triton=False): need_transpose = self._outgoing ^ a if not chunked: - p = compute_projection_helper(pair, mask, a) + p = compute_projection_helper(pair, mask, a, use_triton) if need_transpose: p = p.transpose(-1, -2) else: @@ -255,6 +842,7 @@ def compute_projection(pair, mask, a=True, chunked=True): pair[..., i : i + inplace_chunk_size, :, :], mask[..., i : i + inplace_chunk_size, :, :], a, + use_triton, ) if need_transpose: pair_chunk = pair_chunk.transpose(-1, -2) @@ -269,7 +857,9 @@ def compute_projection(pair, mask, a=True, chunked=True): # We start by fully manifesting a. 
In addition to the input, this
         # brings total memory consumption to 2x z (disregarding size of chunks)
         # [*, N, N, c]
-        a = compute_projection(z, mask, True, chunked=True)
+        a = compute_projection(
+            z, mask, True, chunked=True, use_triton=use_triton_triangle_kernels
+        )
 
         if inplace_chunk_size is not None:
             n = a.shape[-1]
@@ -375,45 +965,125 @@ def flip_z_cache_(z_cache, z):
             z_chunk_b = slice_tensor(
                 z_cache, z_cache_offset, z_cache_offset + offset, row_dim
             )
-            b_chunk = compute_projection(
-                z_chunk_b, mask_chunk, a=False, chunked=False
+            b_chunk = compute_projection(
+                z_chunk_b,
+                mask_chunk,
+                a=False,
+                chunked=False,
+                use_triton=use_triton_triangle_kernels,
             )
             del z_chunk_b
 
-            x_chunk = torch.einsum("...ij,...jk->...ik", a, b_chunk)
-            x_chunk = permute_final_dims(x_chunk, (1, 2, 0))
-            x_chunk = self.layer_norm_out(x_chunk)
-            x_chunk = self.linear_z(x_chunk)
-
-            # The g dimension (col_dim) is parallel to and ahead of the
-            # overwrites in z. We can extract the g chunk normally.
-            z_chunk_g = slice_tensor(z, i, i + offset, col_dim)
-            g_chunk = self.linear_g(self.layer_norm_in(z_chunk_g))
-            g_chunk.sigmoid_()
-            del z_chunk_g
+            if use_triton_triangle_kernels:
+                x_chunk = torch.einsum("...ij,...jk->...ik", a, b_chunk)
+                x_chunk = permute_final_dims(x_chunk, (1, 2, 0))
+                x_chunk = triton_layernorm(
+                    x_chunk,
+                    self.layer_norm_out.weight,
+                    self.layer_norm_out.bias,
+                    self.layer_norm_out.eps,
+                )
+                x_chunk = triton_linear(
+                    x_chunk, self.linear_z.weight, self.linear_z.bias
+                )
 
-            x_chunk *= g_chunk
+                # The g dimension (col_dim) is parallel to and ahead of the
+                # overwrites in z. We can extract the g chunk normally.
+ z_chunk_g = slice_tensor(z, i, i + offset, col_dim) + g_input = triton_layernorm( + z_chunk_g, + self.layer_norm_in.weight, + self.layer_norm_in.bias, + self.layer_norm_in.eps, + ) + del z_chunk_g + + # Fused: sigmoid(linear(g_input)) * x_chunk [+ z_slice] + z_slicer = empty_slicer(z) + z_slicer[col_dim] = slice(i, i + offset) + if with_add: + z[tuple(z_slicer)] = triton_linear_fused( + g_input, + self.linear_g.weight, + self.linear_g.bias, + other=x_chunk, + add_tensor=z[tuple(z_slicer)], + apply_sigmoid=True, + ) + else: + # Fused: sigmoid(linear(g_input)) * x_chunk -> z[slice] + z[tuple(z_slicer)] = triton_linear_fused( + g_input, + self.linear_g.weight, + self.linear_g.bias, + other=x_chunk, + apply_sigmoid=True, + ) + else: + x_chunk = torch.einsum("...ij,...jk->...ik", a, b_chunk) + x_chunk = permute_final_dims(x_chunk, (1, 2, 0)) + x_chunk = self.layer_norm_out(x_chunk) + x_chunk = self.linear_z(x_chunk) + + # The g dimension (col_dim) is parallel to and ahead of the + # overwrites in z. We can extract the g chunk normally. 
+                    z_chunk_g = slice_tensor(z, i, i + offset, col_dim)
+                    g_chunk = self.linear_g(self.layer_norm_in(z_chunk_g))
+                    g_chunk.sigmoid_()
+                    del z_chunk_g
+
+                    x_chunk *= g_chunk
+
+                    # Write the columns into z in-place
+                    z_slicer = empty_slicer(z)
+                    z_slicer[col_dim] = slice(i, i + offset)
+                    if with_add:
+                        z[tuple(z_slicer)] += x_chunk
+                    else:
+                        z[tuple(z_slicer)] = x_chunk
 
-            # Write the columns into z in-place
-            z_slicer = empty_slicer(z)
-            z_slicer[col_dim] = slice(i, i + offset)
+        else:
+            b = compute_projection(
+                z, mask, False, False, use_triton=use_triton_triangle_kernels
+            )
+            if use_triton_triangle_kernels:
+                x = torch.einsum("...ij,...jk->...ik", a, b)
+                x = triton_layernorm(
+                    x,
+                    self.layer_norm_out.weight,
+                    self.layer_norm_out.bias,
+                    self.layer_norm_out.eps,
+                )
+                x = triton_linear(x, self.linear_z.weight, self.linear_z.bias)
+                # Fused: sigmoid(linear(z)) * x [+ z]
                 if with_add:
-                    z[tuple(z_slicer)] += x_chunk
+                    z[:] = triton_linear_fused(
+                        z,
+                        self.linear_g.weight,
+                        self.linear_g.bias,
+                        other=x,
+                        add_tensor=z,
+                        apply_sigmoid=True,
+                    )
                 else:
-                    z[tuple(z_slicer)] = x_chunk
-        else:
+                    z[:] = triton_linear_fused(
+                        z,
+                        self.linear_g.weight,
+                        self.linear_g.bias,
+                        other=x,
+                        apply_sigmoid=True,
+                    )
-            b = compute_projection(z, mask, False, False)
-            x = torch.einsum("...ij,...jk->...ik", a, b)
-            x = self.layer_norm_out(x)
-            x = self.linear_z(x)
-            g = self.linear_g(z)
-            g.sigmoid_()
-            x *= g
-            if with_add:
-                z += x
+            else:
+                x = torch.einsum("...ij,...jk->...ik", a, b)
+                x = self.layer_norm_out(x)
+                x = self.linear_z(x)
+                g = self.linear_g(z)
+                g.sigmoid_()
+                x *= g
+                if with_add:
+                    z += x
+                else:
-                z = x
+                    z = x
 
         return z
@@ -423,6 +1093,7 @@ def forward(
         mask: torch.Tensor | None = None,
         inplace_safe: bool = False,
         use_cueq_triangle_kernels: bool = False,
+        use_triton_triangle_kernels: bool = False,
         _add_with_inplace: bool = False,
         _inplace_chunk_size: int | None = 256,
     ) -> torch.Tensor:
@@ -475,6 +1146,7 @@ def forward(
             mask,
             inplace_chunk_size=_inplace_chunk_size,
with_add=_add_with_inplace,
+            use_triton_triangle_kernels=use_triton_triangle_kernels,
         )
 
         return x
@@ -551,12 +1223,104 @@ def __init__(
             self.c_z, self.c_hidden * 2, **linear_init_params.linear_ab_g
         )
 
+    def _triton_inference_forward(
+        self,
+        z: torch.Tensor,
+        mask: torch.Tensor | None = None,
+        _inplace_chunk_size: int | None = None,
+        with_add: bool = True,
+    ):
+        """
+        Args:
+            z:
+                A [*, N, N, C_z] pair representation
+            mask:
+                A [*, N, N] pair mask
+            with_add:
+                If True, z is overwritten with (z + update). Otherwise, it is
+                overwritten with (update).
+        Returns:
+            A reference to the overwritten z
+        """
+
+        z_norm_in = triton_layernorm(
+            z,
+            self.layer_norm_in.weight,
+            self.layer_norm_in.bias,
+            self.layer_norm_in.eps,
+        )
+
+        p_g = triton_linear_fused(
+            z_norm_in,
+            self.linear_ab_g.weight,
+            self.linear_ab_g.bias,
+            apply_sigmoid=True,
+        )
+        p = triton_linear_fused(
+            z_norm_in,
+            self.linear_ab_p.weight,
+            self.linear_ab_p.bias,
+            other=p_g,
+            mask=mask,
+        )
+        a = p[..., : self.c_hidden]
+        b = p[..., self.c_hidden :]
+
+        if self._outgoing:
+            a = permute_final_dims(a, (2, 0, 1))
+            b = permute_final_dims(b, (2, 1, 0))
+        else:
+            a = permute_final_dims(a, (2, 1, 0))
+            b = permute_final_dims(b, (2, 0, 1))
+
+        if _inplace_chunk_size is not None:
+            for i in range(0, a.shape[-3], _inplace_chunk_size):
+                a_chunk = a[..., i : i + _inplace_chunk_size, :, :]
+                b_chunk = b[..., i : i + _inplace_chunk_size, :, :]
+                a[..., i : i + _inplace_chunk_size, :, :] = torch.einsum(
+                    "...ij,...jk->...ik", a_chunk, b_chunk
+                )
+
+            x = a
+        else:
+            x = torch.einsum("...ij,...jk->...ik", a, b)
+
+        x = permute_final_dims(x, (1, 2, 0))
+        x = triton_layernorm(
+            x,
+            self.layer_norm_out.weight,
+            self.layer_norm_out.bias,
+            self.layer_norm_out.eps,
+        )
+        x = triton_linear(x, self.linear_z.weight, self.linear_z.bias)
+        # Fused: sigmoid(linear(z_norm_in)) * x [+ z]
+        if with_add:
+            z[:] = triton_linear_fused(
+                z_norm_in,
+                self.linear_g.weight,
+                self.linear_g.bias,
+                other=x,
+                add_tensor=z,
+ apply_sigmoid=True, + ) + else: + z[:] = triton_linear_fused( + z_norm_in, + self.linear_g.weight, + self.linear_g.bias, + other=x, + apply_sigmoid=True, + ) + + return z + def _inference_forward( self, z: torch.Tensor, mask: torch.Tensor | None = None, _inplace_chunk_size: int | None = None, with_add: bool = True, + use_triton_triangle_kernels: bool = False, ): """ Args: @@ -567,6 +1331,8 @@ def _inference_forward( with_add: If True, z is overwritten with (z + update). Otherwise, it is overwritten with (update). + use_triton_triangle_kernels: + If True, uses Triton kernels. Returns: A reference to the overwritten z """ @@ -575,6 +1341,11 @@ def _inference_forward( mask = mask.unsqueeze(-1) + if use_triton_triangle_kernels: + return self._triton_inference_forward( + z, mask, _inplace_chunk_size, with_add + ) + def compute_projection_helper(pair, mask): p = self.linear_ab_g(pair) p.sigmoid_() @@ -611,6 +1382,7 @@ def forward( mask: torch.Tensor | None = None, inplace_safe: bool = False, use_cueq_triangle_kernels: bool = False, + use_triton_triangle_kernels: bool = False, _add_with_inplace: bool = False, _inplace_chunk_size: int | None = 256, ) -> torch.Tensor: @@ -646,6 +1418,7 @@ def forward( mask, _inplace_chunk_size=_inplace_chunk_size, with_add=_add_with_inplace, + use_triton_triangle_kernels=use_triton_triangle_kernels, ) return x diff --git a/openfold3/core/model/primitives/attention.py b/openfold3/core/model/primitives/attention.py index 4842c16de..64a8df2d8 100644 --- a/openfold3/core/model/primitives/attention.py +++ b/openfold3/core/model/primitives/attention.py @@ -1,4 +1,5 @@ # Copyright 2026 AlQuraishi Laboratory +# Copyright 2026 Advanced Micro Devices, Inc. 
# Copyright 2025 NVIDIA Corporation # Copyright 2021 DeepMind Technologies Limited # @@ -47,6 +48,13 @@ if ds4s_is_installed: from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention +try: + from openfold3.core.kernels.triton.evoformer import TritonEvoformer +except ImportError: + TritonEvoformer = None + +TRITON_AVAILABLE = TritonEvoformer is not None + cueq_is_installed = is_cuequivariance_available() if cueq_is_installed: from cuequivariance_ops_torch.triangle_attention import ( @@ -71,6 +79,28 @@ def cueq_would_fall_back(n_token: int, hidden_dim: int, dtype: torch.dtype): DEFAULT_LMA_Q_CHUNK_SIZE = 1024 DEFAULT_LMA_KV_CHUNK_SIZE = 4096 +# Cache of all-zero pair-bias tensors used when the Triton kernel requires a pair_bias +# argument but the caller has none (e.g. MSA column attention). Keyed by +# (device_str, dtype, B, H, L) so tensors are never reused across incompatible shapes. +_ZERO_PAIR_BIAS_CACHE: dict = {} + + +def _get_zero_pair_bias( + device: torch.device, dtype: torch.dtype, B: int, H: int, L: int +) -> torch.Tensor: + """Return a cached [B, 1, H, L, L] zero tensor, allocating only on first call""" + key = (str(device), dtype, int(B), int(H), int(L)) + t = _ZERO_PAIR_BIAS_CACHE.get(key) + if t is None or t.shape != (B, 1, H, L, L): + t = torch.zeros((B, 1, H, L, L), device=device, dtype=dtype) + _ZERO_PAIR_BIAS_CACHE[key] = t + return t + + +def clear_zero_pair_bias_cache(): + """Clear the cached zero pair-bias tensors.""" + _ZERO_PAIR_BIAS_CACHE.clear() + @torch.jit.ignore def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor: @@ -309,6 +339,7 @@ def forward( biases: list[torch.Tensor] | None = None, use_deepspeed_evo_attention: bool = False, use_cueq_triangle_kernels: bool = False, + use_triton_triangle_kernels: bool = False, use_lma: bool = False, lma_q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE, lma_kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE, @@ -329,6 +360,9 @@ def forward( use_cueq_triangle_kernels: whether to 
use cuequivariance triangle kernels. Mutually exclusive with use_lma + use_triton_triangle_kernels: + Whether to use Triton-based memory-efficient attention kernel. + Mutually exclusive with other kernel options. use_lma: Whether to use low-memory attention (Staats & Rabe 2021). If none of the "use_<...>" flags are True, a stock PyTorch @@ -364,8 +398,13 @@ def forward( if use_deepspeed_evo_attention and q_x.shape[-2] <= 16: use_deepspeed_evo_attention = False + if use_triton_triangle_kernels and q_x.shape[-2] <= 16: + use_triton_triangle_kernels = False + attn_options = [ - use_deepspeed_evo_attention or use_cueq_triangle_kernels, + use_deepspeed_evo_attention + or use_cueq_triangle_kernels + or use_triton_triangle_kernels, use_lma, use_high_precision, ] @@ -375,11 +414,15 @@ def forward( if biases is None: biases = [] - # DeepSpeed attention kernel and cuequivariance kernel apply scaling internally + # DeepSpeed, cuequivariance, and Triton kernels apply scaling internally q, k, v = self._prep_qkv( q_x, kv_x, - apply_scale=not (use_deepspeed_evo_attention or use_cueq_triangle_kernels), + apply_scale=not ( + use_deepspeed_evo_attention + or use_cueq_triangle_kernels + or use_triton_triangle_kernels + ), ) # cuequivariance kernel takes precedence over use_deepspeed_evo_attention @@ -398,6 +441,15 @@ def forward( "provide up to two bias terms" ) o = _deepspeed_evo_attn(q, k, v, biases) + elif use_triton_triangle_kernels: + if not TRITON_AVAILABLE or TritonEvoformer is None: + raise RuntimeError( + "Triton kernels requested (use_triton_triangle_kernels=True) " + "but openfold3.core.kernels.triton is not available. " + "Ensure the package is installed with Triton support." 
+ ) + o = _triton_evo_attn(q, k, v, biases) + elif use_lma: biases = [ b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],)) @@ -561,6 +613,73 @@ def convert_dtype(x: torch.Tensor) -> torch.Tensor: return o +@torch.compiler.disable +def _triton_evo_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + biases: list[torch.Tensor], +): + """ + Compute attention using the Triton EvoformerAttention kernel. + + Args: + q: + [*, H, Q, C_hidden] query data + k: + [*, H, K, C_hidden] key data + v: + [*, H, V, C_hidden] value data + biases: + List of biases that broadcast to [*, H, Q, K] + """ + + def reshape_dims(x): + no_batch_dims = len(x.shape[:-3]) + if no_batch_dims < 2: + return x.reshape(*((1,) * (2 - no_batch_dims) + x.shape)) + if no_batch_dims > 2: + return x.reshape(*((x.shape[0], -1) + x.shape[-3:])) + return x + + # [*, Q/K, H, C_hidden] + q = q.transpose(-2, -3) + k = k.transpose(-2, -3) + v = v.transpose(-2, -3) + + # Reshape to [B, N_seq, N_res, H, C_hidden] as required by the kernel. + orig_shape = q.shape + if len(orig_shape[:-3]) != 2: + q = reshape_dims(q) + k = reshape_dims(k) + v = reshape_dims(v) + biases = [reshape_dims(b) for b in biases] + + # When there is no pair bias (e.g. MSA column attention), pass a zero tensor + # and set HAS_PAIR_BIAS=False so the kernel skips all pair_bias loads. + has_pair_bias = len(biases) == 2 + if not has_pair_bias: + Batch, N_seq, N_res, Head, Dim = q.shape + biases.append(_get_zero_pair_bias(q.device, q.dtype, Batch, Head, N_res)) + + # Kernel requires fp16 or bf16; cast if needed. 
+ orig_dtype = q.dtype + if orig_dtype not in [torch.bfloat16, torch.float16]: + o = TritonEvoformer( + q.to(dtype=torch.bfloat16), + k.to(dtype=torch.bfloat16), + v.to(dtype=torch.bfloat16), + biases[0].to(dtype=torch.bfloat16), + biases[1].to(dtype=torch.bfloat16), + has_pair_bias, + ).to(dtype=orig_dtype) + else: + o = TritonEvoformer(q, k, v, biases[0], biases[1], has_pair_bias) + + o = o.reshape(orig_shape) + return o + + def _lma( q: torch.Tensor, k: torch.Tensor, diff --git a/openfold3/core/model/structure/diffusion_module.py b/openfold3/core/model/structure/diffusion_module.py index 5a3b65dd0..aba42592a 100644 --- a/openfold3/core/model/structure/diffusion_module.py +++ b/openfold3/core/model/structure/diffusion_module.py @@ -1,4 +1,5 @@ # Copyright 2026 AlQuraishi Laboratory +# Copyright 2026 Advanced Micro Devices, Inc. # Copyright 2021 DeepMind Technologies Limited # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -176,6 +177,7 @@ def forward( chunk_size: int | None = None, use_deepspeed_evo_attention: bool = False, use_cueq_triangle_kernels: bool = False, + use_triton_triangle_kernels: bool = False, use_lma: bool = False, use_high_precision_attention: bool = False, _mask_trans: bool = True, @@ -206,6 +208,8 @@ def forward( Inference-time subbatch size use_deepspeed_evo_attention: Whether to use DeepSpeed Evo Attention kernel + use_triton_triangle_kernels: + Whether to use Triton triangle attention kernel use_lma: Whether to use LMA use_high_precision_attention: @@ -248,6 +252,7 @@ def forward( mask=token_mask, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, use_high_precision_attention=use_high_precision_attention, _mask_trans=_mask_trans, @@ -324,6 +329,7 @@ def forward( chunk_size: int | None = None, use_deepspeed_evo_attention: bool = False, use_cueq_triangle_kernels: bool = False, + 
use_triton_triangle_kernels: bool = False, use_lma: bool = False, use_high_precision_attention: bool = False, _mask_trans: bool = True, @@ -348,6 +354,8 @@ def forward( Inference-time subbatch size use_deepspeed_evo_attention: Whether to use DeepSpeed Evo Attention kernel + use_triton_triangle_kernels: + Whether to use Triton triangle attention kernel use_lma: Whether to use LMA use_high_precision_attention: @@ -394,6 +402,7 @@ def forward( chunk_size=chunk_size, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, use_lma=use_lma, use_high_precision_attention=use_high_precision_attention, _mask_trans=_mask_trans, diff --git a/openfold3/entry_points/import_utils.py b/openfold3/entry_points/import_utils.py index 7f3452fcc..6c2540121 100644 --- a/openfold3/entry_points/import_utils.py +++ b/openfold3/entry_points/import_utils.py @@ -1,4 +1,5 @@ # Copyright 2026 AlQuraishi Laboratory +# Copyright 2026 Advanced Micro Devices, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -30,3 +31,7 @@ def _torch_gpu_setup(): ): # Gives a large speedup on Ampere-class GPUs torch.set_float32_matmul_precision("high") + + # On ROCm/HIP backends + if torch.cuda.is_available() and torch.version.hip is not None: + torch.backends.cuda.preferred_blas_library("cublas") diff --git a/openfold3/entry_points/validate_rocm.py b/openfold3/entry_points/validate_rocm.py new file mode 100644 index 000000000..e9fdf38aa --- /dev/null +++ b/openfold3/entry_points/validate_rocm.py @@ -0,0 +1,126 @@ +# Copyright 2026 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Validation script for AMD ROCm inference with OpenFold3. + +Run after installing openfold3 on a ROCm system to verify that the environment +is correctly configured for the Triton kernels: + + validate-openfold3-rocm +""" + +import sys + + +def _check(label: str, ok: bool, detail: str = "") -> bool: + status = "PASS" if ok else "FAIL" + line = f" [{status}] {label}" + if detail: + line += f": {detail}" + print(line) + return ok + + +def main() -> None: + print("OpenFold3 ROCm environment check\n") + all_ok = True + + # 1. PyTorch importable + try: + import torch + + torch_version = torch.__version__ + torch_ok = True + except ImportError: + torch_version = "not found" + torch_ok = False + all_ok &= _check("PyTorch installed", torch_ok, torch_version) + + if not torch_ok: + print("\nInstall PyTorch for ROCm first:") + print( + " pip install torch torchvision torchaudio" + " --index-url https://download.pytorch.org/whl/rocm7.2" + ) + sys.exit(1) + + # 2. ROCm / HIP build + hip_version = torch.version.hip + hip_ok = hip_version is not None + all_ok &= _check( + "PyTorch built with ROCm (HIP)", + hip_ok, + hip_version if hip_ok else "torch.version.hip is None — this is a CUDA build", + ) + + # 3. ROCm GPU visible + gpu_ok = torch.cuda.is_available() + device_name = torch.cuda.get_device_name(0) if gpu_ok else "none" + all_ok &= _check("ROCm GPU visible", gpu_ok, device_name) + + # 4. 
Triton importable + try: + import triton + + triton_version = triton.__version__ + triton_ok = True + except ImportError: + triton_version = "not found" + triton_ok = False + all_ok &= _check("Triton installed", triton_ok, triton_version) + + if not triton_ok: + print("\nTriton should be bundled with the ROCm PyTorch wheel.") + print("Re-install PyTorch for ROCm:") + print( + " pip install torch torchvision torchaudio" + " --index-url https://download.pytorch.org/whl/rocm7.2" + ) + sys.exit(1) + + # 5. Triton backend is HIP + try: + import triton.runtime + + backend = triton.runtime.driver.active.get_current_target().backend + hip_backend_ok = backend == "hip" + all_ok &= _check("Triton backend is HIP", hip_backend_ok, backend) + except Exception as e: + all_ok &= _check("Triton backend is HIP", False, str(e)) + + # 6. OpenFold3 Triton evoformer kernel loads + try: + from openfold3.core.kernels.triton.evoformer import TritonEvoformer + + kernel_ok = TritonEvoformer is not None + all_ok &= _check("Triton evoformer kernel loaded", kernel_ok) + except Exception as e: + all_ok &= _check("Triton evoformer kernel loaded", False, str(e)) + + # Summary + print() + if all_ok: + print("All checks passed. OpenFold3 ROCm inference is correctly configured.") + else: + print("One or more checks failed. See above for details.") + print( + "Installation instructions: " + "https://github.com/aqlaboratory/openfold-3/blob/main/docs/source/Installation.md" + ) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/openfold3/projects/of3_all_atom/config/model_config.py b/openfold3/projects/of3_all_atom/config/model_config.py index 11ba1dba0..68cd97a81 100644 --- a/openfold3/projects/of3_all_atom/config/model_config.py +++ b/openfold3/projects/of3_all_atom/config/model_config.py @@ -1,4 +1,5 @@ # Copyright 2026 AlQuraishi Laboratory +# Copyright 2026 Advanced Micro Devices, Inc. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,11 +14,15 @@ # limitations under the License. import ml_collections as mlc +import torch from openfold3.projects.of3_all_atom.config import ( linear_init_config as lin_init, ) +# Detect AMD/ROCm hardware. On ROCm, use Triton kernels; on CUDA, use DeepSpeed. +_is_rocm = torch.version.hip is not None + # Hidden dimensions c_s = mlc.FieldReference(384, field_type=int) c_z = mlc.FieldReference(128, field_type=int) @@ -94,6 +99,9 @@ # exclusive with use_lma. "use_deepspeed_evo_attention": False, "use_cueq_triangle_kernels": False, + # Use Triton-based memory-efficient attention kernel. Mutually + # exclusive with use_deepspeed_evo_attention and use_lma. + "use_triton_triangle_kernels": False, # Use Staats & Rabe's low-memory attention algorithm. Mutually # exclusive with use_deepspeed_evo_attention. "use_lma": False, @@ -104,8 +112,9 @@ }, "eval": { "chunk_size": None, - "use_deepspeed_evo_attention": True, + "use_deepspeed_evo_attention": not _is_rocm, "use_cueq_triangle_kernels": False, + "use_triton_triangle_kernels": _is_rocm, "use_lma": False, "msa_module": { "swiglu_chunk_token_cutoff": None, diff --git a/openfold3/projects/of3_all_atom/model.py b/openfold3/projects/of3_all_atom/model.py index 9f7724f68..28a9b24b4 100644 --- a/openfold3/projects/of3_all_atom/model.py +++ b/openfold3/projects/of3_all_atom/model.py @@ -1,4 +1,5 @@ # Copyright 2026 AlQuraishi Laboratory +# Copyright 2026 Advanced Micro Devices, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -239,6 +240,7 @@ def run_trunk( chunk_size=mode_mem_settings.chunk_size, _mask_trans=True, use_deepspeed_evo_attention=mode_mem_settings.use_deepspeed_evo_attention, + use_triton_triangle_kernels=mode_mem_settings.use_triton_triangle_kernels, use_cueq_triangle_kernels=mode_mem_settings.use_cueq_triangle_kernels, use_lma=mode_mem_settings.use_lma, inplace_safe=inplace_safe, @@ -269,6 +271,7 @@ def run_trunk( chunk_size=mode_mem_settings.chunk_size, transition_ckpt_chunk_size=transition_ckpt_chunk_size, use_deepspeed_evo_attention=mode_mem_settings.use_deepspeed_evo_attention, + use_triton_triangle_kernels=mode_mem_settings.use_triton_triangle_kernels, use_cueq_triangle_kernels=mode_mem_settings.use_cueq_triangle_kernels, use_lma=mode_mem_settings.use_lma, _mask_trans=True, @@ -284,6 +287,7 @@ def run_trunk( chunk_size=mode_mem_settings.chunk_size, transition_ckpt_chunk_size=transition_ckpt_chunk_size, use_deepspeed_evo_attention=mode_mem_settings.use_deepspeed_evo_attention, + use_triton_triangle_kernels=mode_mem_settings.use_triton_triangle_kernels, use_cueq_triangle_kernels=mode_mem_settings.use_cueq_triangle_kernels, use_lma=mode_mem_settings.use_lma, inplace_safe=inplace_safe, @@ -300,6 +304,7 @@ def run_trunk( pair_mask=pair_mask.to(dtype=s.dtype), chunk_size=mode_mem_settings.chunk_size, use_deepspeed_evo_attention=mode_mem_settings.use_deepspeed_evo_attention, + use_triton_triangle_kernels=mode_mem_settings.use_triton_triangle_kernels, use_cueq_triangle_kernels=mode_mem_settings.use_cueq_triangle_kernels, use_lma=mode_mem_settings.use_lma, inplace_safe=inplace_safe, @@ -388,6 +393,7 @@ def _rollout( use_conditioning=True, chunk_size=mode_mem_settings.chunk_size, use_deepspeed_evo_attention=mode_mem_settings.use_deepspeed_evo_attention, + use_triton_triangle_kernels=mode_mem_settings.use_triton_triangle_kernels, use_cueq_triangle_kernels=mode_mem_settings.use_cueq_triangle_kernels, use_lma=mode_mem_settings.use_lma, _mask_trans=True, @@ -412,6 +418,7 @@ def 
_rollout( use_zij_trunk_embedding=use_trunk_embedding, chunk_size=mode_mem_settings.chunk_size, use_deepspeed_evo_attention=mode_mem_settings.use_deepspeed_evo_attention, + use_triton_triangle_kernels=mode_mem_settings.use_triton_triangle_kernels, use_cueq_triangle_kernels=mode_mem_settings.use_cueq_triangle_kernels, use_lma=mode_mem_settings.use_lma, inplace_safe=inplace_safe, diff --git a/openfold3/tests/compare_utils.py b/openfold3/tests/compare_utils.py index 11e7bc640..c52682b37 100644 --- a/openfold3/tests/compare_utils.py +++ b/openfold3/tests/compare_utils.py @@ -20,14 +20,21 @@ from openfold3.core.kernels.cueq_utils import is_cuequivariance_available +def skip_if_rocm(): + is_rocm = torch.cuda.is_available() and torch.version.hip is not None + return unittest.skipIf(is_rocm, "Not supported on ROCm/HIP") + + def skip_unless_ds4s_installed(): deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None ds4s_is_installed = ( deepspeed_is_installed and importlib.util.find_spec("deepspeed.ops.deepspeed4science") is not None ) + is_rocm = torch.cuda.is_available() and torch.version.hip is not None return unittest.skipUnless( - ds4s_is_installed, "Requires DeepSpeed with version ≥ 0.10.4" + ds4s_is_installed and not is_rocm, + "Requires DeepSpeed with version ≥ 0.10.4 (not supported on ROCm/HIP)", ) diff --git a/openfold3/tests/conftest.py b/openfold3/tests/conftest.py index 76bd897b3..41b75a231 100644 --- a/openfold3/tests/conftest.py +++ b/openfold3/tests/conftest.py @@ -1,12 +1,20 @@ import biotite.setup_ccd import numpy as np import pytest +import torch from biotite.structure import AtomArray from openfold3.core.data.primitives.structure.component import BiotiteCCDWrapper from openfold3.setup_openfold import setup_biotite_ccd +@pytest.fixture(scope="session", autouse=True) +def rocm_blas_setup(): + """On ROCm/HIP backends, prefer rocBLAS over hipBLASLt.""" + if torch.cuda.is_available() and torch.version.hip is not None: + 
torch.backends.cuda.preferred_blas_library("cublas") + + @pytest.fixture def dummy_atom_array(): # Create dummy atom array diff --git a/openfold3/tests/test_kernels.py b/openfold3/tests/test_kernels.py index adb0bd719..226ba1760 100644 --- a/openfold3/tests/test_kernels.py +++ b/openfold3/tests/test_kernels.py @@ -1,4 +1,5 @@ # Copyright 2026 AlQuraishi Laboratory +# Copyright 2026 Advanced Micro Devices, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -44,6 +45,8 @@ torch.backends.cuda.matmul.allow_tf32 = True pytestmark = [pytest.mark.slow] +torch.backends.cuda.preferred_blas_library("cublas") + @compare_utils.skip_unless_cuda_available() class TestKernels(unittest.TestCase): @@ -51,6 +54,7 @@ def _compare_attn_kernel_forward( self, use_deepspeed_evo_attention=False, use_cueq_triangle_kernels=False, + use_triton_triangle_kernels=False, dtype=torch.float32, ): """Compare attention with and without using DeepSpeed Evoformer kernel.""" @@ -92,6 +96,7 @@ def _compare_attn_kernel_forward( biases=biases, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, ).cpu() err = torch.max(torch.abs(kernel_out - real_out)) @@ -129,10 +134,25 @@ def test_cueq_forward_bf16(self): dtype=torch.bfloat16, ) + @compare_utils.skip_unless_triton_installed() + def test_triton_forward_bf16(self): + self._compare_attn_kernel_forward( + use_triton_triangle_kernels=True, + dtype=torch.bfloat16, + ) + + @compare_utils.skip_unless_triton_installed() + def test_triton_forward_fp32(self): + self._compare_attn_kernel_forward( + use_triton_triangle_kernels=True, + dtype=torch.float32, + ) + def _compare_attn_kernel_backward( self, use_deepspeed_evo_attention=False, use_cueq_triangle_kernels=False, + use_triton_triangle_kernels=False, dtype=torch.float32, ): """ @@ -199,6 +219,7 @@ def 
init_attn(): biases=biases_repro, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, ) loss_repro = torch.mean(out_repro) loss_repro.backward() @@ -266,6 +287,13 @@ def test_cueq_backward_bf16(self): dtype=torch.bfloat16, ) + @compare_utils.skip_unless_triton_installed() + def test_triton_backward_bf16(self): + self._compare_attn_kernel_backward( + use_triton_triangle_kernels=True, + dtype=torch.bfloat16, + ) + @compare_utils.skip_unless_cueq_installed() def test_cueq_tri_mult_fwd(self): batch = consts.batch_size @@ -390,6 +418,7 @@ def _compare_pairformer( self, use_deepspeed_evo_attention=False, use_cueq_triangle_kernels=False, + use_triton_triangle_kernels=False, dtype=torch.float32, chunk_size=None, eps=2e-2, @@ -407,7 +436,9 @@ def _compare_pairformer( """ batch_size = consts.batch_size if chunk_size is not None and ( - use_deepspeed_evo_attention or use_cueq_triangle_kernels + use_deepspeed_evo_attention + or use_cueq_triangle_kernels + or use_triton_triangle_kernels ): # Chunk tuning is not supported with batch size > 1 for DeepSpeed kernel batch_size = 1 @@ -480,6 +511,7 @@ def _compare_pairformer( pair_mask=z_mask, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, chunk_size=chunk_size, ) out_repro_single_ds = F.layer_norm(out_repro_single_ds, (consts.c_s,)).cpu() @@ -555,6 +587,34 @@ def test_compare_pairformer_cueq_fp32_chunk(self): eps=4e-2, ) + @compare_utils.skip_unless_triton_installed() + def test_compare_pairformer_triton_bf16(self): + """Run Pairformer comparison test with Triton kernel and BF16 precision.""" + self._compare_pairformer( + use_triton_triangle_kernels=True, + dtype=torch.bfloat16, + eps=2e-2, + ) + + @compare_utils.skip_unless_triton_installed() + def test_compare_pairformer_triton_fp32(self): + 
"""Run Pairformer comparison test with Triton kernel and FP32 precision.""" + self._compare_pairformer( + use_triton_triangle_kernels=True, + dtype=torch.float32, + eps=2e-2, + ) + + @compare_utils.skip_unless_triton_installed() + def test_compare_pairformer_triton_fp32_chunk(self): + """Run Pairformer comparison test with Triton kernel and chunk tuning enabled.""" + self._compare_pairformer( + use_triton_triangle_kernels=True, + dtype=torch.float32, + chunk_size=4, + eps=4e-2, + ) + def _compare_diffusion_transformer( self, use_deepspeed_evo_attention=False, @@ -655,6 +715,7 @@ def _compare_template_stack( self, use_deepspeed_evo_attention=False, use_cueq_triangle_kernels=False, + use_triton_triangle_kernels=False, dtype=torch.float32, chunk_size=None, ): @@ -724,6 +785,7 @@ def to_device(t): chunk_size=chunk_size, use_deepspeed_evo_attention=use_deepspeed_evo_attention, use_cueq_triangle_kernels=use_cueq_triangle_kernels, + use_triton_triangle_kernels=use_triton_triangle_kernels, ) compare_utils.assert_max_abs_diff_small(out_repro, out_repro_ds, 2e-2) @@ -778,6 +840,34 @@ def test_compare_template_stack_cueq_fp32_chunk(self): chunk_size=4, ) + @compare_utils.skip_unless_triton_installed() + def test_compare_template_stack_triton_fp32_chunk(self): + self._compare_template_stack( + use_deepspeed_evo_attention=False, + use_cueq_triangle_kernels=False, + use_triton_triangle_kernels=True, + dtype=torch.float32, + chunk_size=4, + ) + + @compare_utils.skip_unless_triton_installed() + def test_compare_template_stack_triton_fp32(self): + self._compare_template_stack( + use_deepspeed_evo_attention=False, + use_cueq_triangle_kernels=False, + use_triton_triangle_kernels=True, + dtype=torch.float32, + ) + + @compare_utils.skip_unless_triton_installed() + def test_compare_template_stack_triton_bf16(self): + self._compare_template_stack( + use_deepspeed_evo_attention=False, + use_cueq_triangle_kernels=False, + use_triton_triangle_kernels=True, + dtype=torch.bfloat16, + ) + if 
__name__ == "__main__": unittest.main() diff --git a/openfold3/tests/test_of3_model.py b/openfold3/tests/test_of3_model.py index 3064c6023..1dbf07a1e 100644 --- a/openfold3/tests/test_of3_model.py +++ b/openfold3/tests/test_of3_model.py @@ -36,6 +36,7 @@ def run_model( train=True, reduce_model_size=True, use_deepspeed_evo_attention=False, + use_triton_triangle_kernels=False, ): device = "cuda" if torch.cuda.is_available() else "cpu" @@ -57,6 +58,9 @@ def run_model( config.settings.memory.eval.use_deepspeed_evo_attention = ( use_deepspeed_evo_attention ) + + if use_triton_triangle_kernels: + config.settings.memory.eval.use_triton_triangle_kernels = True config.architecture.loss_module.diffusion.chunk_size = 16 of3 = OpenFold3AllAtom(config).to(device=device, dtype=dtype) @@ -163,6 +167,7 @@ def test_shape_small_kernels(self, dtype, model_phase): n_msa = 10 n_templ = 3 + is_rocm = torch.cuda.is_available() and torch.version.hip is not None is_train = model_phase == "train" self.run_model( batch_size=batch_size, @@ -172,7 +177,8 @@ def test_shape_small_kernels(self, dtype, model_phase): dtype=dtype, train=is_train, reduce_model_size=True, - use_deepspeed_evo_attention=True, + use_deepspeed_evo_attention=not is_rocm, + use_triton_triangle_kernels=is_rocm, ) @pytest.mark.slow @@ -187,6 +193,7 @@ def test_shape_large_eval(self, dtype): n_msa = 16384 n_templ = 4 + is_rocm = torch.cuda.is_available() and torch.version.hip is not None self.run_model( batch_size=batch_size, n_token=n_token, @@ -195,7 +202,8 @@ def test_shape_large_eval(self, dtype): dtype=dtype, train=False, reduce_model_size=False, - use_deepspeed_evo_attention=True, + use_deepspeed_evo_attention=not is_rocm, + use_triton_triangle_kernels=is_rocm, ) @compare_utils.skip_unless_triton_installed() @@ -206,6 +214,7 @@ def test_shape_large_bf16_train(self): n_msa = 16384 n_templ = 4 + is_rocm = torch.cuda.is_available() and torch.version.hip is not None self.run_model( batch_size=batch_size, n_token=n_token, @@ 
-214,5 +223,6 @@ def test_shape_large_bf16_train(self): dtype=torch.bfloat16, train=True, reduce_model_size=False, - use_deepspeed_evo_attention=True, + use_deepspeed_evo_attention=not is_rocm, + use_triton_triangle_kernels=is_rocm, ) diff --git a/pyproject.toml b/pyproject.toml index c428ef6cf..3c20708b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ Repository = "https://github.com/aqlaboratory/openfold-3" [project.scripts] run_openfold="openfold3.run_openfold:cli" setup_openfold="openfold3.setup_openfold:main" +validate-openfold3-rocm="openfold3.entry_points.validate_rocm:main" [dependency-groups] test = [