I've been building runtime inference monitoring tooling on MLX (tracking entropy and representation drift across long generation runs with mlx-lm models) and needed to capture hidden states from specific transformer layers during and after generation.
Since `mlx.nn.Module` doesn't have a `register_forward_hook` equivalent, the cleanest pattern I've found is to temporarily swap a layer with a thin wrapper:
```python
import mlx.core as mx
import mlx.nn as nn

class CaptureWrapper(nn.Module):
    def __init__(self, layer):
        super().__init__()
        self.layer = layer
        self.captured = []

    def __call__(self, *args, **kwargs):
        out = self.layer(*args, **kwargs)
        self.captured.append(out)
        return out

# Usage (e.g., capturing layer 16 of a Qwen2.5 model loaded via mlx-lm):
original = model.model.layers[16]
wrapper = CaptureWrapper(original)
model.model.layers[16] = wrapper
try:
    # run forward pass or generate
    output = model(tokens)
    mx.eval(output)
    hidden_states = wrapper.captured[0]
finally:
    model.model.layers[16] = original  # always restore
```
A couple of gotchas I ran into:
- `mx.eval()` returns `None`: never write `result = mx.eval(tensor)`. Call `mx.eval(tensor)` for its side effect, then use `tensor` directly.
- Multi-layer capture requires multiple wrappers and gets verbose quickly. Wrapping N layers means N swap/restore pairs in the `finally` block.
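To cut down on that N-fold swap/restore boilerplate, I've been experimenting with a small context manager. This is just a sketch (`capture_layers` is my own helper, not anything in MLX), reusing the capture wrapper idea from above but shown without the `nn.Module` base so the snippet is self-contained; it works on anything with an indexable list of callable layers:

```python
from contextlib import contextmanager

class CaptureWrapper:
    """Thin wrapper that records every output of the wrapped layer."""
    def __init__(self, layer):
        self.layer = layer
        self.captured = []

    def __call__(self, *args, **kwargs):
        out = self.layer(*args, **kwargs)
        self.captured.append(out)
        return out

@contextmanager
def capture_layers(layers, indices):
    """Temporarily wrap layers[i] for each i in indices; restore on exit."""
    originals = {i: layers[i] for i in indices}
    wrappers = {i: CaptureWrapper(layers[i]) for i in indices}
    for i, w in wrappers.items():
        layers[i] = w
    try:
        yield wrappers  # wrappers[i].captured holds outputs for layer i
    finally:
        for i, orig in originals.items():
            layers[i] = orig  # always restore, even if the forward pass throws
```

So multi-layer capture becomes `with capture_layers(model.model.layers, [8, 16, 24]) as caps: ...` and the hidden states land in `caps[16].captured`, with restoration guaranteed by the `finally`.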
This works well enough for my use case (Qwen2.5-7B-4bit on M4 Pro, capturing from individual layers during generation). But it doesn't compose cleanly if you want to capture from several layers at once, or if you want to hook into submodules within a layer (e.g., just the attention output before the residual add).
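The same swap does work one level deeper on a submodule attribute, though it gets brittle because the attribute name is model-specific (e.g. `self_attn` in mlx-lm's Qwen2 implementation, if I'm reading the source right). A generic sketch of my own, with a hypothetical `wrap_submodule` helper:

```python
def wrap_submodule(parent, attr, wrapper_cls):
    """Swap parent.<attr> for a capture wrapper; return (wrapper, restore_fn)."""
    original = getattr(parent, attr)
    wrapper = wrapper_cls(original)
    setattr(parent, attr, wrapper)

    def restore():
        # Put the unwrapped submodule back in place.
        setattr(parent, attr, original)

    return wrapper, restore
```

Used like `wrapper, restore = wrap_submodule(model.model.layers[16], "self_attn", CaptureWrapper)` with `restore()` in a `finally` block, same as before. It works, but every hook site now hard-codes an implementation detail of the model class.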
I know `mx.custom_function` covers the backward/gradient case (ml-explore/mlx-examples#987), but for the forward-only case (capturing activations without needing gradients) a lightweight callback would be cleaner than swap-and-restore.
Has anyone found a better pattern for this? Or would a minimal forward hook interface on mlx.nn.Module be something the project would consider? Happy to hear if the wrapper approach is the intended way to go.
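For concreteness, here's roughly the shape of interface I have in mind. This is purely hypothetical (none of these names exist in MLX), modeled on PyTorch's `register_forward_hook`, and shown on a plain class rather than `nn.Module` so the sketch stands alone:

```python
class HookMixin:
    """Hypothetical PyTorch-style forward hooks (not an MLX API)."""

    def register_forward_hook(self, fn):
        hooks = getattr(self, "_forward_hooks", None)
        if hooks is None:
            hooks = self._forward_hooks = []
        hooks.append(fn)

        def remove():
            hooks.remove(fn)

        return remove  # handle to unregister the hook

    def _run_hooks(self, inputs, output):
        # Called by the module at the end of its forward pass.
        for fn in getattr(self, "_forward_hooks", []):
            fn(self, inputs, output)

# A module opting in would invoke _run_hooks at the end of __call__:
class DoubleIt(HookMixin):
    def __call__(self, x):
        out = x * 2
        self._run_hooks((x,), out)
        return out
```

Capture then becomes `remove = model.model.layers[16].register_forward_hook(lambda m, inp, out: acts.append(out))`, with no layer swapping and no `finally` block needed beyond calling `remove()`.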