|
| 1 | +import torch |
| 2 | +from unittest.mock import MagicMock |
| 3 | + |
| 4 | +# Import the classes from your local version |
| 5 | +from transformers import KernelConfig |
| 6 | + |
| 7 | + |
def test_fix_on_mac():
    """Check that a CUDA model keeps only the CUDA entry of a multi-backend mapping.

    A ``KernelConfig`` is built with both ``"cuda"`` and ``"rocm"`` kernels
    registered for the ``"RMSNorm"`` layer. ``create_compatible_mapping`` is
    then run against a mocked model whose only parameter reports device type
    ``"cuda"``, so no GPU and no real model are needed.

    Raises:
        AssertionError: If ``"RMSNorm"`` is missing from the resulting
            mapping, if ``"rocm"`` is still registered, or if ``"cuda"`` is
            not registered.

    Any exception raised by ``create_compatible_mapping`` itself is allowed
    to propagate so the test runner reports it with a full traceback instead
    of it being printed and silently ignored.
    """
    print("Testing KernelConfig Fix")
    kernel_mapping = {
        "RMSNorm": {
            "cuda": "kernels-community/layer_norm:LlamaRMSNorm",
            "rocm": "kernels-community/layer_norm:LlamaRMSNorm",
        }
    }

    # Create the config under test.
    kernel_config = KernelConfig(kernel_mapping)

    # Create a mock model that pretends to live on a CUDA device, so the
    # real Llama model is not needed.
    mock_model = MagicMock()
    mock_model.training = False

    # Mock a parameter whose device type is "cuda".
    mock_param = MagicMock()
    mock_param.device.type = "cuda"
    # Hand out a fresh iterator on every call: a single shared iterator
    # would be exhausted after the first ``parameters()`` call and make any
    # later call look like a model without parameters.
    mock_model.parameters.side_effect = lambda *args, **kwargs: iter([mock_param])

    # Mock named_modules so one layer advertises the kernel name "RMSNorm".
    mock_layer = MagicMock()
    mock_layer.kernel_layer_name = "RMSNorm"
    mock_model.named_modules.return_value = [("layers.0", mock_layer)]

    print("Simulating model load...")

    # Trigger the logic under test; failures surface as real exceptions.
    kernel_config.create_compatible_mapping(mock_model)

    # Verify the result.
    result_mapping = kernel_config.kernel_mapping

    print("\n--- Result ---")
    assert "RMSNorm" in result_mapping, "FAIL: RMSNorm not found in mapping."

    backends = list(result_mapping["RMSNorm"].keys())
    print(f"Registered Backends: {backends}")

    # Check ROCm first so a leftover ROCm entry is reported as the original
    # bug even when CUDA is present as well.
    assert "rocm" not in backends, (
        "FAIL: ROCm is present. It overwrote CUDA (The bug is still there)."
    )
    assert "cuda" in backends, "FAIL: CUDA backend is missing from the mapping."

    print("PASS: The fix worked! ROCm was ignored, preserving CUDA.")
| 60 | + |
| 61 | + |
# Script entry point: run the check directly when this file is executed
# rather than imported.
if __name__ == "__main__":
    test_fix_on_mac()
0 commit comments