diff --git a/src/transformers/utils/kernel_config.py b/src/transformers/utils/kernel_config.py
index fe9f368ac8e7..e1c80589a035 100644
--- a/src/transformers/utils/kernel_config.py
+++ b/src/transformers/utils/kernel_config.py
@@ -208,6 +208,7 @@ def create_compatible_mapping(self, model, compile=False):
         from kernels import Mode
 
         compatible_mapping = {}
+        current_device = infer_device(model)
         for layer_name, kernel in self.kernel_mapping.items():
             # Infer Mode: use Mode.TRAINING if model is training, else use Mode.INFERENCE
             mode = Mode.TRAINING if model.training else Mode.INFERENCE
@@ -216,10 +217,11 @@ def create_compatible_mapping(self, model, compile=False):
 
             if isinstance(kernel, str):
                 repo_name = kernel
-                device = infer_device(model)
-                add_to_mapping(layer_name, device, repo_name, mode, compatible_mapping)
+                add_to_mapping(layer_name, current_device, repo_name, mode, compatible_mapping)
             elif isinstance(kernel, dict):
                 for device, repo_name in kernel.items():
+                    if device != current_device:
+                        continue
                     add_to_mapping(layer_name, device, repo_name, mode, compatible_mapping)
 
         self.kernel_mapping = compatible_mapping
diff --git a/tests/kernels/test_kernels.py b/tests/kernels/test_kernels.py
index 6311629ac4f2..bc4e64dc0a9c 100644
--- a/tests/kernels/test_kernels.py
+++ b/tests/kernels/test_kernels.py
@@ -17,7 +17,7 @@
 import copy
 import types
 
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
 from transformers import AutoModelForCausalLM, AutoTokenizer, KernelConfig
 from transformers.integrations.hub_kernels import (
@@ -401,3 +401,74 @@ def spy_kernelize(model, device=None, mode=None):
         self.assertTrue(any(m == Mode.TRAINING for m in last_modes))
         self.model.eval()
         self.assertTrue(any(m == Mode.INFERENCE for m in last_modes))
+
+
+@require_kernels
+class TestKernelMappingDeviceFiltering(TestCasePlus):
+    """Test that kernel mappings correctly filter by current device."""
+
+    def test_multi_device_mapping_filters_correctly(self):
+        """
+        Test that when a kernel_mapping contains multiple devices (cuda, rocm),
+        only the current device's kernel is registered.
+        Regression test for issue where ROCm overwrote CUDA mapping.
+        """
+        kernel_mapping = {
+            "RMSNorm": {
+                "cuda": "kernels-community/layer_norm:LlamaRMSNorm",
+                "rocm": "kernels-community/layer_norm:LlamaRMSNorm",
+            }
+        }
+
+        kernel_config = KernelConfig(kernel_mapping)
+
+        # Create a mock model on CUDA device
+        mock_model = MagicMock()
+        mock_model.training = False
+
+        # Mock parameter with CUDA device
+        mock_param = MagicMock()
+        mock_param.device.type = "cuda"
+        mock_model.parameters.return_value = iter([mock_param])
+
+        # Mock named_modules with RMSNorm layer
+        mock_layer = MagicMock()
+        mock_layer.kernel_layer_name = "RMSNorm"
+        mock_model.named_modules.return_value = [("layers.0", mock_layer)]
+
+        # Trigger the mapping creation
+        kernel_config.create_compatible_mapping(mock_model)
+
+        # Verify results
+        result_mapping = kernel_config.kernel_mapping
+
+        self.assertIn("RMSNorm", result_mapping, "RMSNorm should be in mapping")
+        backends = list(result_mapping["RMSNorm"].keys())
+
+        # Assert only CUDA is present, not ROCm
+        self.assertIn("cuda", backends, "CUDA backend should be registered")
+        self.assertNotIn("rocm", backends, "ROCm backend should NOT be registered on CUDA device")
+
+    def test_single_device_mapping_still_works(self):
+        """
+        Test that single-device mappings continue to work as expected.
+        """
+        kernel_mapping = {"RMSNorm": "kernels-community/layer_norm:LlamaRMSNorm"}
+
+        kernel_config = KernelConfig(kernel_mapping)
+
+        # Create a mock model
+        mock_model = MagicMock()
+        mock_model.training = False
+
+        mock_param = MagicMock()
+        mock_param.device.type = "cuda"
+        mock_model.parameters.return_value = iter([mock_param])
+
+        mock_layer = MagicMock()
+        mock_layer.kernel_layer_name = "RMSNorm"
+        mock_model.named_modules.return_value = [("layers.0", mock_layer)]
+        kernel_config.create_compatible_mapping(mock_model)
+
+        result_mapping = kernel_config.kernel_mapping
+        self.assertIn("RMSNorm", result_mapping, "RMSNorm should be in mapping")