From 04e27cbcaf83a9c29312ed89d06273a20f1fd458 Mon Sep 17 00:00:00 2001
From: Aaraviitkgp
Date: Fri, 28 Nov 2025 04:39:15 +0530
Subject: [PATCH 1/8] mapping error resolved with test check

---
 mapping_test.py                         | 63 +++++++++++++++++++++++++
 src/transformers/utils/kernel_config.py |  4 +-
 2 files changed, 66 insertions(+), 1 deletion(-)
 create mode 100644 mapping_test.py

diff --git a/mapping_test.py b/mapping_test.py
new file mode 100644
index 000000000000..be635c9fa789
--- /dev/null
+++ b/mapping_test.py
@@ -0,0 +1,63 @@
+import torch
+from unittest.mock import MagicMock
+
+# Import the classes from your local version
+from transformers import KernelConfig
+
+
+def test_fix_on_mac():
+    print("Testing KernelConfig Fix")
+    kernel_mapping = {
+        "RMSNorm": {
+            "cuda": "kernels-community/layer_norm:LlamaRMSNorm",
+            "rocm": "kernels-community/layer_norm:LlamaRMSNorm",
+        }
+    }
+
+    # 3. Create the config
+    kernel_config = KernelConfig(kernel_mapping)
+
+    # 4. Create a MOCK model
+    # We pretend this is a model on a CUDA device so we don't need the real Llama model
+    mock_model = MagicMock()
+    mock_model.training = False
+
+    # Mock the parameter device to return 'cuda'
+    mock_param = MagicMock()
+    mock_param.device.type = "cuda"
+    mock_model.parameters.return_value = iter([mock_param])
+
+    # Mock named_modules to register the layer name "RMSNorm"
+    mock_layer = MagicMock()
+    mock_layer.kernel_layer_name = "RMSNorm"
+    mock_model.named_modules.return_value = [("layers.0", mock_layer)]
+
+    print("Simulating model load...")
+
+    # 5. Trigger the logic you fixed
+    try:
+        kernel_config.create_compatible_mapping(mock_model)
+    except Exception as e:
+        print(f"Execution crashed: {e}")
+        return
+
+    # 6. Verify the result
+    result_mapping = kernel_config.kernel_mapping
+
+    print("\n--- Result ---")
+    if "RMSNorm" in result_mapping:
+        backends = result_mapping["RMSNorm"].keys()
+        print(f"Registered Backends: {list(backends)}")
+
+        if "cuda" in backends and "rocm" not in backends:
+            print("PASS: The fix worked! ROCm was ignored, preserving CUDA.")
+        elif "rocm" in backends:
+            print("FAIL: ROCm is present. It overwrote CUDA (The bug is still there).")
+        else:
+            print("FAIL: Mapping is empty.")
+    else:
+        print("FAIL: RMSNorm not found in mapping.")
+
+
+if __name__ == "__main__":
+    test_fix_on_mac()
diff --git a/src/transformers/utils/kernel_config.py b/src/transformers/utils/kernel_config.py
index fe9f368ac8e7..f9792ff4c0a8 100644
--- a/src/transformers/utils/kernel_config.py
+++ b/src/transformers/utils/kernel_config.py
@@ -208,6 +208,7 @@ def create_compatible_mapping(self, model, compile=False):
         from kernels import Mode
 
         compatible_mapping = {}
+        current_device = infer_device(model)
         for layer_name, kernel in self.kernel_mapping.items():
             # Infer Mode: use Mode.TRAINING if model is training, else use Mode.INFERENCE
             mode = Mode.TRAINING if model.training else Mode.INFERENCE
@@ -216,10 +217,11 @@
 
             if isinstance(kernel, str):
                 repo_name = kernel
-                device = infer_device(model)
                 add_to_mapping(layer_name, device, repo_name, mode, compatible_mapping)
             elif isinstance(kernel, dict):
                 for device, repo_name in kernel.items():
+                    if device != current_device:
+                        continue
                     add_to_mapping(layer_name, device, repo_name, mode, compatible_mapping)
 
         self.kernel_mapping = compatible_mapping

From e328d0c34b4ec47cbb5fb7bd1e9c775d73d1de6f Mon Sep 17 00:00:00 2001
From: Aaraviitkgp
Date: Fri, 28 Nov 2025 04:54:30 +0530
Subject: [PATCH 2/8] Fix undefined variable 'device' in kernel_config

---
 src/transformers/utils/kernel_config.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/utils/kernel_config.py b/src/transformers/utils/kernel_config.py
index f9792ff4c0a8..e1c80589a035 100644
--- a/src/transformers/utils/kernel_config.py
+++ b/src/transformers/utils/kernel_config.py
@@ -217,7 +217,7 @@ def create_compatible_mapping(self, model, compile=False):
             if isinstance(kernel, str):
                 repo_name = kernel
-                add_to_mapping(layer_name, device, repo_name, mode, compatible_mapping)
+                add_to_mapping(layer_name, current_device, repo_name, mode, compatible_mapping)
             elif isinstance(kernel, dict):
                 for device, repo_name in kernel.items():
                     if device != current_device:
                         continue
                     add_to_mapping(layer_name, device, repo_name, mode, compatible_mapping)

From 111381ca24e2b10166d87c0389633891160423b2 Mon Sep 17 00:00:00 2001
From: Aaraviitkgp
Date: Mon, 1 Dec 2025 21:36:11 +0530
Subject: [PATCH 3/8] added test in test_kernels

---
 tests/kernels/test_kernels.py | 116 +++++++++++++++++++++++++++++++++-
 1 file changed, 115 insertions(+), 1 deletion(-)

diff --git a/tests/kernels/test_kernels.py b/tests/kernels/test_kernels.py
index 6311629ac4f2..a96e8855d09d 100644
--- a/tests/kernels/test_kernels.py
+++ b/tests/kernels/test_kernels.py
@@ -17,7 +17,7 @@
 
 import copy
 import types
-from unittest.mock import patch
+from unittest.mock import patch, MagicMock
 
 from transformers import AutoModelForCausalLM, AutoTokenizer, KernelConfig
 from transformers.integrations.hub_kernels import (
@@ -401,3 +401,117 @@ def spy_kernelize(model, device=None, mode=None):
         self.assertTrue(any(m == Mode.TRAINING for m in last_modes))
         self.model.eval()
         self.assertTrue(any(m == Mode.INFERENCE for m in last_modes))
+
+@require_kernels
+class TestKernelMappingDeviceFiltering(TestCasePlus):
+    """Test that kernel mappings correctly filter by current device."""
+
+    def test_multi_device_mapping_filters_correctly(self):
+        """
+        Test that when a kernel_mapping contains multiple devices (cuda, rocm),
+        only the current device's kernel is registered.
+        Regression test for issue where ROCm overwrote CUDA mapping.
+ """ + kernel_mapping = { + "RMSNorm": { + "cuda": "kernels-community/layer_norm:LlamaRMSNorm", + "rocm": "kernels-community/layer_norm:LlamaRMSNorm", + } + } + + kernel_config = KernelConfig(kernel_mapping) + + # Create a mock model on CUDA device + mock_model = MagicMock() + mock_model.training = False + + # Mock parameter with CUDA device + mock_param = MagicMock() + mock_param.device.type = "cuda" + mock_model.parameters.return_value = iter([mock_param]) + + # Mock named_modules with RMSNorm layer + mock_layer = MagicMock() + mock_layer.kernel_layer_name = "RMSNorm" + mock_model.named_modules.return_value = [("layers.0", mock_layer)] + + # Trigger the mapping creation + kernel_config.create_compatible_mapping(mock_model) + + # Verify results + result_mapping = kernel_config.kernel_mapping + + self.assertIn("RMSNorm", result_mapping, "RMSNorm should be in mapping") + backends = list(result_mapping["RMSNorm"].keys()) + + # Assert only CUDA is present, not ROCm + self.assertIn("cuda", backends, "CUDA backend should be registered") + self.assertNotIn("rocm", backends, "ROCm backend should NOT be registered on CUDA device") + + def test_rocm_device_filters_correctly(self): + """ + Test that ROCm device correctly filters out CUDA kernels. + """ + kernel_mapping = { + "RMSNorm": { + "cuda": "kernels-community/layer_norm:LlamaRMSNorm", + "rocm": "kernels-community/layer_norm:LlamaRMSNorm", + } + } + + kernel_config = KernelConfig(kernel_mapping) + + # Create a mock model on ROCm device + mock_model = MagicMock() + mock_model.training = False + + # Mock parameter with ROCm device + mock_param = MagicMock() + mock_param.device.type = "rocm" + mock_model.parameters.return_value = iter([mock_param]) + + # Mock named_modules with RMSNorm layer + mock_layer = MagicMock() + mock_layer.kernel_layer_name = "RMSNorm" + mock_model.named_modules.return_value = [("layers.0", mock_layer)] + + # Trigger the mapping creation + kernel_config.create_compatible_mapping(mock_model) + + # Verify results + result_mapping = kernel_config.kernel_mapping + + self.assertIn("RMSNorm", result_mapping, "RMSNorm should be in mapping") + backends = list(result_mapping["RMSNorm"].keys()) + + # Assert only ROCm is present, not CUDA + self.assertIn("rocm", backends, "ROCm backend should be registered") + self.assertNotIn("cuda", backends, "CUDA backend should NOT be registered on ROCm device") + + def test_single_device_mapping_still_works(self): + """ + Test that single-device mappings continue to work as expected. 
+ """ + kernel_mapping = { + "RMSNorm": "kernels-community/layer_norm:LlamaRMSNorm" + } + + kernel_config = KernelConfig(kernel_mapping) + + # Create a mock model + mock_model = MagicMock() + mock_model.training = False + + mock_param = MagicMock() + mock_param.device.type = "cuda" + mock_model.parameters.return_value = iter([mock_param]) + + mock_layer = MagicMock() + mock_layer.kernel_layer_name = "RMSNorm" + mock_model.named_modules.return_value = [("layers.0", mock_layer)] + + # Should not raise any errors + kernel_config.create_compatible_mapping(mock_model) + + result_mapping = kernel_config.kernel_mapping + self.assertIn("RMSNorm", result_mapping, "RMSNorm should be in mapping") \ No newline at end of file From 21eda55f458f85987da5ac244f259fcfb13a8298 Mon Sep 17 00:00:00 2001 From: Aaraviitkgp Date: Mon, 1 Dec 2025 21:37:17 +0530 Subject: [PATCH 4/8] added test with proper format --- tests/kernels/test_kernels.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/kernels/test_kernels.py b/tests/kernels/test_kernels.py index a96e8855d09d..940a0b704c72 100644 --- a/tests/kernels/test_kernels.py +++ b/tests/kernels/test_kernels.py @@ -402,6 +402,7 @@ def spy_kernelize(model, device=None, mode=None): self.model.eval() self.assertTrue(any(m == Mode.INFERENCE for m in last_modes)) + @require_kernels class TestKernelMappingDeviceFiltering(TestCasePlus): """Test that kernel mappings correctly filter by current device.""" @@ -492,9 +493,7 @@ def test_single_device_mapping_still_works(self): """ Test that single-device mappings continue to work as expected. """ - kernel_mapping = { - "RMSNorm": "kernels-community/layer_norm:LlamaRMSNorm" - } + kernel_mapping = {"RMSNorm": "kernels-community/layer_norm:LlamaRMSNorm"} kernel_config = KernelConfig(kernel_mapping) @@ -514,4 +513,4 @@ def test_single_device_mapping_still_works(self): kernel_config.create_compatible_mapping(mock_model) result_mapping = kernel_config.kernel_mapping - self.assertIn("RMSNorm", result_mapping, "RMSNorm should be in mapping") \ No newline at end of file + self.assertIn("RMSNorm", result_mapping, "RMSNorm should be in mapping") From b96acaf1f4fd9eab274d6fb46f8834c22d302b88 Mon Sep 17 00:00:00 2001 From: Aaraviitkgp Date: Mon, 1 Dec 2025 21:43:30 +0530 Subject: [PATCH 5/8] added test with proper format once again --- tests/kernels/test_kernels.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_kernels.py b/tests/kernels/test_kernels.py index 940a0b704c72..0f9701c8f939 100644 --- a/tests/kernels/test_kernels.py +++ b/tests/kernels/test_kernels.py @@ -17,7 +17,7 @@ import copy import types -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch from transformers import AutoModelForCausalLM, AutoTokenizer, KernelConfig from transformers.integrations.hub_kernels import ( @@ -402,7 +402,6 @@ def spy_kernelize(model, device=None, mode=None): self.model.eval() self.assertTrue(any(m == Mode.INFERENCE for m in last_modes)) - @require_kernels class TestKernelMappingDeviceFiltering(TestCasePlus): """Test that kernel mappings correctly filter by current device.""" @@ -493,7 +492,9 @@ def test_single_device_mapping_still_works(self): """ Test that single-device mappings continue to work as expected. 
""" - kernel_mapping = {"RMSNorm": "kernels-community/layer_norm:LlamaRMSNorm"} + kernel_mapping = { + "RMSNorm": "kernels-community/layer_norm:LlamaRMSNorm" + } kernel_config = KernelConfig(kernel_mapping) From 31426ce57243fdb52d8cb140d8c46e2d64650a7c Mon Sep 17 00:00:00 2001 From: Aaraviitkgp Date: Mon, 1 Dec 2025 21:46:02 +0530 Subject: [PATCH 6/8] Removed mapping_test.py file --- mapping_test.py | 63 ------------------------------------------------- 1 file changed, 63 deletions(-) delete mode 100644 mapping_test.py diff --git a/mapping_test.py b/mapping_test.py deleted file mode 100644 index be635c9fa789..000000000000 --- a/mapping_test.py +++ /dev/null @@ -1,63 +0,0 @@ -import torch -from unittest.mock import MagicMock - -# Import the classes from your local version -from transformers import KernelConfig - - -def test_fix_on_mac(): - print("Testing KernelConfig Fix") - kernel_mapping = { - "RMSNorm": { - "cuda": "kernels-community/layer_norm:LlamaRMSNorm", - "rocm": "kernels-community/layer_norm:LlamaRMSNorm", - } - } - - # 3. Create the config - kernel_config = KernelConfig(kernel_mapping) - - # 4. Create a MOCK model - # We pretend this is a model on a CUDA device so we don't need the real Llama model - mock_model = MagicMock() - mock_model.training = False - - # Mock the parameter device to return 'cuda' - mock_param = MagicMock() - mock_param.device.type = "cuda" - mock_model.parameters.return_value = iter([mock_param]) - - # Mock named_modules to register the layer name "RMSNorm" - mock_layer = MagicMock() - mock_layer.kernel_layer_name = "RMSNorm" - mock_model.named_modules.return_value = [("layers.0", mock_layer)] - - print("Simulating model load...") - - # 5. Trigger the logic you fixed - try: - kernel_config.create_compatible_mapping(mock_model) - except Exception as e: - print(f"Execution crashed: {e}") - return - - # 6. Verify the result - result_mapping = kernel_config.kernel_mapping - - print("\n--- Result ---") - if "RMSNorm" in result_mapping: - backends = result_mapping["RMSNorm"].keys() - print(f"Registered Backends: {list(backends)}") - - if "cuda" in backends and "rocm" not in backends: - print("PASS: The fix worked! ROCm was ignored, preserving CUDA.") - elif "rocm" in backends: - print("FAIL: ROCm is present. It overwrote CUDA (The bug is still there).") - else: - print("FAIL: Mapping is empty.") - else: - print("FAIL: RMSNorm not found in mapping.") - - -if __name__ == "__main__": - test_fix_on_mac() From 9614946a0a9f22b5495fa44c7c5218764f5b1dcf Mon Sep 17 00:00:00 2001 From: Aaraviitkgp Date: Mon, 1 Dec 2025 21:50:49 +0530 Subject: [PATCH 7/8] reformated with ruff --- tests/kernels/test_kernels.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_kernels.py b/tests/kernels/test_kernels.py index 0f9701c8f939..d4b988f112ff 100644 --- a/tests/kernels/test_kernels.py +++ b/tests/kernels/test_kernels.py @@ -402,6 +402,7 @@ def spy_kernelize(model, device=None, mode=None): self.model.eval() self.assertTrue(any(m == Mode.INFERENCE for m in last_modes)) + @require_kernels class TestKernelMappingDeviceFiltering(TestCasePlus): """Test that kernel mappings correctly filter by current device.""" @@ -492,9 +493,7 @@ def test_single_device_mapping_still_works(self): """ Test that single-device mappings continue to work as expected. 
""" - kernel_mapping = { - "RMSNorm": "kernels-community/layer_norm:LlamaRMSNorm" - } + kernel_mapping = {"RMSNorm": "kernels-community/layer_norm:LlamaRMSNorm"} kernel_config = KernelConfig(kernel_mapping) From b744a098726cfc4be756f0cdcf712c91c56112f3 Mon Sep 17 00:00:00 2001 From: Aaraviitkgp Date: Tue, 2 Dec 2025 22:23:16 +0530 Subject: [PATCH 8/8] removed the test --- tests/kernels/test_kernels.py | 42 ----------------------------------- 1 file changed, 42 deletions(-) diff --git a/tests/kernels/test_kernels.py b/tests/kernels/test_kernels.py index d4b988f112ff..bc4e64dc0a9c 100644 --- a/tests/kernels/test_kernels.py +++ b/tests/kernels/test_kernels.py @@ -449,46 +449,6 @@ def test_multi_device_mapping_filters_correctly(self): self.assertIn("cuda", backends, "CUDA backend should be registered") self.assertNotIn("rocm", backends, "ROCm backend should NOT be registered on CUDA device") - def test_rocm_device_filters_correctly(self): - """ - Test that ROCm device correctly filters out CUDA kernels. - """ - kernel_mapping = { - "RMSNorm": { - "cuda": "kernels-community/layer_norm:LlamaRMSNorm", - "rocm": "kernels-community/layer_norm:LlamaRMSNorm", - } - } - - kernel_config = KernelConfig(kernel_mapping) - - # Create a mock model on ROCm device - mock_model = MagicMock() - mock_model.training = False - - # Mock parameter with ROCm device - mock_param = MagicMock() - mock_param.device.type = "rocm" - mock_model.parameters.return_value = iter([mock_param]) - - # Mock named_modules with RMSNorm layer - mock_layer = MagicMock() - mock_layer.kernel_layer_name = "RMSNorm" - mock_model.named_modules.return_value = [("layers.0", mock_layer)] - - # Trigger the mapping creation - kernel_config.create_compatible_mapping(mock_model) - - # Verify results - result_mapping = kernel_config.kernel_mapping - - self.assertIn("RMSNorm", result_mapping, "RMSNorm should be in mapping") - backends = list(result_mapping["RMSNorm"].keys()) - - # Assert only ROCm is present, not CUDA - self.assertIn("rocm", backends, "ROCm backend should be registered") - self.assertNotIn("cuda", backends, "CUDA backend should NOT be registered on ROCm device") - def test_single_device_mapping_still_works(self): """ Test that single-device mappings continue to work as expected. @@ -508,8 +468,6 @@ def test_single_device_mapping_still_works(self): mock_layer = MagicMock() mock_layer.kernel_layer_name = "RMSNorm" mock_model.named_modules.return_value = [("layers.0", mock_layer)] - - # Should not raise any errors kernel_config.create_compatible_mapping(mock_model) result_mapping = kernel_config.kernel_mapping