I'm currently trying to optimize a model that uses DeepSeek-V3-style MLA (multi-head latent attention) plus a dense MLP (`self.down_proj(swiglu(self.gate_proj(x), self.up_proj(x)))`). I have little experience with Metal programming and profiling.
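For reference, the dense MLP above can be sketched in NumPy as follows. This assumes `swiglu(gate, x) = silu(gate) * x`, which is the form mlx_lm uses. The weights here are stored as (in, out) and applied as `x @ W`, whereas `nn.Linear` stores (out, in) and computes `x @ W.T`. The sizes are toy values, not the real model's.

```python
import numpy as np

def silu(x):
    # SiLU / swish: x * sigmoid(x)
    return x / (1.0 + np.exp(-x))

def swiglu(gate, up):
    # Gated activation: SiLU(gate) scales the "up" branch elementwise
    return silu(gate) * up

def dense_mlp(x, w_gate, w_up, w_down):
    # x: (tokens, hidden); w_gate, w_up: (hidden, intermediate); w_down: (intermediate, hidden)
    gate = x @ w_gate                 # gate_proj
    up = x @ w_up                     # up_proj
    return swiglu(gate, up) @ w_down  # down_proj

rng = np.random.default_rng(0)
hidden, intermediate = 8, 32  # toy sizes for illustration only
x = rng.standard_normal((1, hidden)).astype(np.float32)
w_gate = rng.standard_normal((hidden, intermediate)).astype(np.float32)
w_up = rng.standard_normal((hidden, intermediate)).astype(np.float32)
w_down = rng.standard_normal((intermediate, hidden)).astype(np.float32)

y = dense_mlp(x, w_gate, w_up, w_down)
print(y.shape)  # (1, 8)
```

Per token, this amounts to three matrix multiplications plus one elementwise gate. In a captured trace, the MLP part of each layer should therefore appear as a small, fixed group of kernels.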
Following the Metal debugger page in the MLX docs (https://ml-explore.github.io/mlx/build/html/dev/metal_debugger.html), I wrote a simple script to capture a GPU trace (`.gputrace`):
```python
import mlx.core as mx
from mlx_lm.utils import load

# Load the model
model_path = "/Users/molly/Youtu-VL-4B-Instruct-mlx-int4"
tokenizer_config = {"trust_remote_code": True}
model, processor = load(model_path, tokenizer_config)

# A single-token batch of shape (1, 1), passed as an mx.array
inputs = mx.array([[0]])

# Warm up so one-time setup costs stay out of the trace
for _ in range(5):
    output = model(inputs=inputs)
    mx.eval(output)

# Capture only works when the script is run with MTL_CAPTURE_ENABLED=1 set
trace_file = "mlx_trace.gputrace"

# Trace one forward pass
mx.metal.start_capture(trace_file)
output = model(inputs=inputs)
mx.eval(output)
mx.metal.stop_capture()
print(output)
```
In the trace there is consistently a ~260 µs gap every two layers, during which the GPU is idle (no ALU or memory utilization). What causes this gap?
Any suggestions on how I can get better at optimizing Metal code are also welcome!
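One way to locate idle gaps systematically is to take kernel start/end timestamps from a trace and scan them for stretches where no kernel is running. The sketch below assumes you already have these timestamps as `(start_us, end_us)` pairs. `find_idle_gaps` is a hypothetical helper, not part of MLX or Xcode. The timings are made-up numbers chosen only to mirror the kind of gap described above.

```python
def find_idle_gaps(intervals, min_gap_us):
    """Return (gap_start, gap_end) pairs where no interval covers the timeline
    for at least min_gap_us microseconds. Overlapping intervals are merged."""
    gaps = []
    busy_until = None
    for start, end in sorted(intervals):
        if busy_until is not None and start - busy_until >= min_gap_us:
            gaps.append((busy_until, start))
        # Extend the busy window; overlapping kernels do not create gaps
        busy_until = end if busy_until is None else max(busy_until, end)
    return gaps

# Hypothetical kernel timings in microseconds (illustrative only, not measured)
kernels = [(0, 40), (40, 95), (95, 130), (390, 430), (430, 470)]
print(find_idle_gaps(kernels, min_gap_us=100))  # [(130, 390)]
```

Once you know where each gap starts, you can check which kernel ends just before it and which one starts just after it. That tells you which part of the layer the idle time sits between.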