diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/triangle.py b/cuequivariance_torch/cuequivariance_torch/primitives/triangle.py
index 2f53e98..dab3abf 100644
--- a/cuequivariance_torch/cuequivariance_torch/primitives/triangle.py
+++ b/cuequivariance_torch/cuequivariance_torch/primitives/triangle.py
@@ -39,7 +39,6 @@ def triangle_attention(
     mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
     return_aux: bool = False,
-    dim_order: Optional[Tuple[int, int, int, int, int]] = None,
 ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     r"""
     Triangle Attention
@@ -65,12 +64,6 @@ def triangle_attention(
         scale (float, optional): Float scale for q (s in the equation). If None, value 1/sqrt(d) is used.
         return_aux (bool): If True, two auxiliary tensors are returned along with the result.
             Defaults to False.
-        dim_order (tuple of 5 ints, optional): Permutation of (0,1,2,3,4) specifying how to
-            reorder the axes of q/k/v/bias from the user's layout to the kernel's [B,N,H,Q,D]
-            layout. This is an O(1) metadata-only permute (no data copy) and incurs zero
-            overhead when the resulting tensor already satisfies sm100f TMA alignment
-            constraints. Example: ``dim_order=(0,1,3,2,4)`` for tensors stored as
-            [B,N,Q,H,D] (H/Q swapped). Defaults to None (no reordering).
 
     Note:
         - B: batch size
@@ -139,7 +132,7 @@ def triangle_attention(
             "Error importing triangle_attention from cuequivariance_ops_torch."
         )
     else:
-        return f(q, k, v, bias, mask, scale, return_aux, dim_order)
+        return f(q, k, v, bias, mask, scale, return_aux)
 
 
 def triangle_multiplicative_update(
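
For callers that passed the removed `dim_order` argument, an equivalent is a `torch.Tensor.permute` at the call site, which is likewise a metadata-only stride rewrite with no data copy. Below is a minimal migration sketch: the shapes, dtype, and bias layout are illustrative assumptions rather than values taken from this patch; only the remaining signature (q, k, v, bias, mask, scale, return_aux) and the kernel's [B, N, H, Q, D] layout come from the diff.

    import torch

    from cuequivariance_torch.primitives.triangle import triangle_attention

    # Assumed, illustrative sizes: batch B, sequences N, heads H,
    # query/key length Q, head dimension D.
    B, N, H, Q, D = 1, 16, 4, 16, 32
    device = torch.device("cuda")  # the fused kernel targets CUDA devices

    # Tensors stored in the caller's [B, N, Q, H, D] layout (H/Q swapped),
    # previously declared to the kernel via dim_order=(0, 1, 3, 2, 4).
    q = torch.randn(B, N, Q, H, D, device=device, dtype=torch.bfloat16)
    k = torch.randn(B, N, Q, H, D, device=device, dtype=torch.bfloat16)
    v = torch.randn(B, N, Q, H, D, device=device, dtype=torch.bfloat16)
    # Bias layout here is an assumption for the example only.
    bias = torch.randn(B, 1, Q, H, Q, device=device, dtype=torch.bfloat16)

    def to_kernel_layout(t: torch.Tensor) -> torch.Tensor:
        # O(1) view: swaps axes 2 and 3 to reach the kernel's
        # [B, N, H, Q, D] layout without copying data.
        return t.permute(0, 1, 3, 2, 4)

    out = triangle_attention(
        to_kernel_layout(q),
        to_kernel_layout(k),
        to_kernel_layout(v),
        to_kernel_layout(bias),
    )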