Commit 9623bfe

mapping error resolved with test check
1 parent bc75bbc
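
In short: when a layer's kernel mapping listed several backends, `create_compatible_mapping` registered all of them, so an entry for another platform (here ROCm) could overwrite the one matching the model's actual device (CUDA). The fix infers the device once up front and skips every backend that does not match it; the new `mapping_test.py` exercises this against a mocked CUDA model.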

File tree

mapping_test.py
src/transformers/utils/kernel_config.py

2 files changed: +67 -2


mapping_test.py

Lines changed: 63 additions & 0 deletions
```diff
@@ -0,0 +1,63 @@
+import torch
+from unittest.mock import MagicMock
+
+# Import the classes from your local version
+from transformers import KernelConfig
+
+
+def test_fix_on_mac():
+    print("Testing KernelConfig Fix")
+    kernel_mapping = {
+        "RMSNorm": {
+            "cuda": "kernels-community/layer_norm:LlamaRMSNorm",
+            "rocm": "kernels-community/layer_norm:LlamaRMSNorm",
+        }
+    }
+
+    # Create the config
+    kernel_config = KernelConfig(kernel_mapping)
+
+    # Create a MOCK model
+    # We pretend this is a model on a CUDA device so we don't need the real Llama model
+    mock_model = MagicMock()
+    mock_model.training = False
+
+    # Mock the parameter device to return 'cuda'
+    mock_param = MagicMock()
+    mock_param.device.type = "cuda"
+    mock_model.parameters.return_value = iter([mock_param])
+
+    # Mock named_modules to register the layer name "RMSNorm"
+    mock_layer = MagicMock()
+    mock_layer.kernel_layer_name = "RMSNorm"
+    mock_model.named_modules.return_value = [("layers.0", mock_layer)]
+
+    print("Simulating model load...")
+
+    # Trigger the fixed logic
+    try:
+        kernel_config.create_compatible_mapping(mock_model)
+    except Exception as e:
+        print(f"Execution crashed: {e}")
+        return
+
+    # Verify the result
+    result_mapping = kernel_config.kernel_mapping
+
+    print("\n--- Result ---")
+    if "RMSNorm" in result_mapping:
+        backends = result_mapping["RMSNorm"].keys()
+        print(f"Registered Backends: {list(backends)}")
+
+        if "cuda" in backends and "rocm" not in backends:
+            print("PASS: The fix worked! ROCm was ignored, preserving CUDA.")
+        elif "rocm" in backends:
+            print("FAIL: ROCm is present. It overwrote CUDA (The bug is still there).")
+        else:
+            print("FAIL: Mapping is empty.")
+    else:
+        print("FAIL: RMSNorm not found in mapping.")
+
+
+if __name__ == "__main__":
+    test_fix_on_mac()
```
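
To reproduce locally, run `python mapping_test.py`. No GPU is required: the model, its parameter device, and its named modules are all mocked. With the patch applied the script should print the PASS line; without it, the ROCm entry survives and the script reports that it overwrote CUDA.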

src/transformers/utils/kernel_config.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -208,6 +208,7 @@ def create_compatible_mapping(self, model, compile=False):
         from kernels import Mode
 
         compatible_mapping = {}
+        current_device = infer_device(model)
         for layer_name, kernel in self.kernel_mapping.items():
             # Infer Mode: use Mode.TRAINING if model is training, else use Mode.INFERENCE
             mode = Mode.TRAINING if model.training else Mode.INFERENCE
@@ -216,10 +217,11 @@ def create_compatible_mapping(self, model, compile=False):
 
             if isinstance(kernel, str):
                 repo_name = kernel
-                device = infer_device(model)
-                add_to_mapping(layer_name, device, repo_name, mode, compatible_mapping)
+                add_to_mapping(layer_name, current_device, repo_name, mode, compatible_mapping)
             elif isinstance(kernel, dict):
                 for device, repo_name in kernel.items():
+                    if device != current_device:
+                        continue
                     add_to_mapping(layer_name, device, repo_name, mode, compatible_mapping)
 
         self.kernel_mapping = compatible_mapping
```
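
For intuition, below is a minimal standalone sketch of the device-filtering behavior the patch introduces. The flattened `{layer: {device: repo}}` result shape is an assumption inferred from the diff and the test above; the real method also tracks `Mode` and delegates to `add_to_mapping`.

```python
# Minimal sketch of the device-filtering behavior introduced by this patch.
# The {layer: {device: repo}} shapes are assumptions inferred from the diff
# and the test above, not the library's actual API.
def create_compatible_mapping(kernel_mapping, current_device):
    compatible = {}
    for layer_name, kernel in kernel_mapping.items():
        if isinstance(kernel, str):
            # A bare repo string is registered for the current device.
            compatible.setdefault(layer_name, {})[current_device] = kernel
        elif isinstance(kernel, dict):
            for device, repo_name in kernel.items():
                if device != current_device:
                    continue  # the fix: ignore backends for other devices
                compatible.setdefault(layer_name, {})[device] = repo_name
    return compatible


mapping = {
    "RMSNorm": {
        "cuda": "kernels-community/layer_norm:LlamaRMSNorm",
        "rocm": "kernels-community/layer_norm:LlamaRMSNorm",
    }
}
# On a CUDA device, only the CUDA entry survives; ROCm no longer overwrites it.
assert create_compatible_mapping(mapping, "cuda") == {
    "RMSNorm": {"cuda": "kernels-community/layer_norm:LlamaRMSNorm"}
}
```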
