Hello, MLX Community,

I am currently porting Z-Image Turbo to MLX. This model is a Diffusion Transformer (DiT) similar to SD3, but it uses 16-channel latents (unlike the standard 4-channel VAE in SDXL), which results in significantly larger activation sizes. At 1024px the latents are `(1, 16, 128, 128)` in `bfloat16`.

I am running this on a MacBook Pro 14 (M3 Pro, 18 GB), but I'm hitting a performance bottleneck that I suspect is memory-bandwidth bound.
I have uploaded the full implementation to my repository; you can check the full code and run it to reproduce: https://github.com/uqer1244/MLX_z-image
Profiling suggests the main cost comes from the Attention layer, specifically its unique 3D RoPE mechanism.
Unlike standard RoPE, this model requires splitting the query/key tensors into three chunks [32, 48, 48], applying different rotary embeddings based on Height/Width/Time axes, and then concatenating them back.
I suspected the Split -> RoPE -> Concat pattern was causing excessive memory read/writes.
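For contrast, the naive pattern can be sketched in plain NumPy (shapes and helper names are my own illustration, not from the repo): every chunk is materialized separately, so Q/K take extra round-trips through memory.

```python
import numpy as np

def rotate_half(x, cos, sin):
    # Interleaved RoPE rotation: even/odd lanes form the (real, imag) pairs.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return np.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1).reshape(x.shape)

def rope_3d_naive(q, cos_hwt, sin_hwt, dims=(32, 48, 48)):
    # Split -> RoPE -> Concat: each chunk is materialized separately.
    chunks, start = [], 0
    for d, cos, sin in zip(dims, cos_hwt, sin_hwt):
        chunks.append(rotate_half(q[..., start:start + d], cos, sin))
        start += d
    return np.concatenate(chunks, axis=-1)
```

With `cos = 1` and `sin = 0` the rotation reduces to the identity, which makes a handy sanity check.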
What I have tried
mx.compile: The entire step function is compiled.
Fused RoPE Args: Instead of splitting the huge (MB-sized) Q/K tensors, I refactored the code to pre-calculate and concatenate the RoPE angles (args) first, then apply cos/sin and the rotation on the full tensor, avoiding the intermediate memory copies.
Here is the current implementation of my Attention layer:
```python
import mlx.core as mx
import mlx.nn as nn

class Attention(nn.Module):
    def __init__(self, dim: int, nheads: int, rope_theta: float = 256.0, eps: float = 1e-5):
        super().__init__()
        # ... (init code) ...
        # Z-Image specific splits
        self.dims = [32, 48, 48]
        self.freqs_cache = {}

    def _get_fused_args(self, positions):
        """Optimization attempt: instead of splitting the Q/K tensors,
        pre-calculate and fuse the 'angles' (theta)."""
        B, L, _ = positions.shape
        if L in self.freqs_cache:
            freqs_tuple = self.freqs_cache[L]
        else:
            # ... (pre-compute freqs logic) ...
            freqs_tuple = freqs_list

        # Calculate angles for each section (H, W, T)
        pos_h = positions[..., 0].astype(mx.float32)
        args_h = pos_h[..., None, None] * freqs_tuple[0][None, None, None, :]
        pos_w = positions[..., 1].astype(mx.float32)
        args_w = pos_w[..., None, None] * freqs_tuple[1][None, None, None, :]
        pos_t = positions[..., 2].astype(mx.float32)
        args_t = pos_t[..., None, None] * freqs_tuple[2][None, None, None, :]

        # Fuse the angles (concatenate the args, NOT the heavy tensors)
        return mx.concatenate([args_h, args_w, args_t], axis=-1)

    def __call__(self, x, mask=None, positions=None):
        B, L, D = x.shape
        # ... (linear projections) ...
        if positions is not None:
            # 1. Get fused angles
            args = self._get_fused_args(positions)
            # 2. Compute sin/cos once for the whole chunk
            cos = mx.cos(args)
            sin = mx.sin(args)
            # 3. Rotate without an explicit split/concat of Q/K
            q1 = q[..., 0::2]
            q2 = q[..., 1::2]
            q = mx.stack([q1 * cos - q2 * sin, q1 * sin + q2 * cos],
                         axis=-1).reshape(B, L, self.nheads, self.head_dim)
            # ... (same for K) ...
        # ... (attention & output) ...
        return self.to_out(output)
```
Despite this "Fused Args" optimization, the speed remains around 15s/step.
My Questions
Memory Bandwidth Bound: Given the 16-channel latents at 1024px (which is 4x larger than SDXL's 4 channels), is ~15s/step simply the physical limit of M3's memory bandwidth? Or should I expect it to be faster?
Kernel Fusion: Does mx.compile effectively handle this kind of partial rotation logic? I am wondering if I'm missing any specific MLX primitives or patterns that would help the compiler fuse these operations better.
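For question 1, a rough traffic estimate may help frame the discussion. The numbers below are all hypothetical (token count from an assumed 2x2 patchify of the 128x128 latent, head count and head dim guessed), but even charging three extra materializations of Q/K per layer to the split/concat pattern, RoPE alone moves well under a gigabyte per layer against the M3 Pro's roughly 150 GB/s:

```python
def rope_traffic_gb(B, L, nheads, head_dim, copies, bytes_per_el=2):
    # Bytes moved for Q and K: each extra materialization reads and writes
    # the tensor once (bfloat16 -> 2 bytes per element).
    elems = B * L * nheads * head_dim
    return 2 * copies * elems * 2 * bytes_per_el / 1e9  # x2 for Q and K, x2 for read+write

# Hypothetical: 4096 tokens (2x2 patchify of 128x128), 16 heads of dim 128.
gb = rope_traffic_gb(B=1, L=4096, nheads=16, head_dim=128, copies=3)
print(f"~{gb:.2f} GB per layer")  # -> ~0.20 GB per layer
```

If these guesses are in the right ballpark, the RoPE copies alone account for only milliseconds of memory traffic per layer, which is part of why I wonder whether the bottleneck is really here.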
Any insights or advice would be greatly appreciated. Thanks for the great work on MLX!