diff --git a/src/transformers/utils/kernel_config.py b/src/transformers/utils/kernel_config.py
index f6400053a434..4dfc81577e44 100644
--- a/src/transformers/utils/kernel_config.py
+++ b/src/transformers/utils/kernel_config.py
@@ -61,14 +61,20 @@ def add_to_mapping(layer_name, device, repo_name, mode, compatible_mapping):
         raise ValueError(f"Only cuda, rocm, xpu and npu devices supported, got: {device}")
     repo_layer_name = repo_name.split(":")[1]
     repo_id = repo_name.split(":")[0]
-    compatible_mapping[layer_name] = {
-        device: {
-            mode: LayerRepository(
-                repo_id=repo_id,
-                layer_name=repo_layer_name,
-            )
-        }
-    }
+
+    # Initialise layer_name entry if it doesn't exist
+    if layer_name not in compatible_mapping:
+        compatible_mapping[layer_name] = {}
+
+    # Initialise device entry if it doesn't exist
+    if device not in compatible_mapping[layer_name]:
+        compatible_mapping[layer_name][device] = {}
+
+    # Add the mode entry (this can overwrite if mode already exists, which is fine)
+    compatible_mapping[layer_name][device][mode] = LayerRepository(
+        repo_id=repo_id,
+        layer_name=repo_layer_name,
+    )
 
 
 class KernelConfig(PushToHubMixin):
diff --git a/tests/kernels/test_kernels.py b/tests/kernels/test_kernels.py
index 6311629ac4f2..606d3a0b89cd 100644
--- a/tests/kernels/test_kernels.py
+++ b/tests/kernels/test_kernels.py
@@ -38,6 +38,7 @@
     torch_device,
 )
 from transformers.utils.import_utils import is_kernels_available
+from transformers.utils.kernel_config import add_to_mapping
 
 
 if is_kernels_available():
@@ -313,6 +314,63 @@ def fake_get_kernel(repo_id, revision=None, version=None, user_agent=None):
                 HUB[name] = original_entry
                 _KERNEL_MODULE_MAPPING.pop(name, None)
 
+    def test_add_to_mapping_multiple_devices(self):
+        """Test that add_to_mapping preserves multiple devices for the same layer_name."""  # noqa: E501
+        compatible_mapping = {}
+
+        # Add cuda device
+        add_to_mapping("RMSNorm", "cuda", "repo:layer", Mode.INFERENCE, compatible_mapping)
+
+        # Add rocm device - should NOT overwrite cuda
+        add_to_mapping("RMSNorm", "rocm", "repo:layer", Mode.INFERENCE, compatible_mapping)
+
+        # Verify both devices exist
+        self.assertIn("cuda", compatible_mapping["RMSNorm"])
+        self.assertIn("rocm", compatible_mapping["RMSNorm"])
+
+        # Verify the structure is correct
+        self.assertIn(Mode.INFERENCE, compatible_mapping["RMSNorm"]["cuda"])
+        self.assertIn(Mode.INFERENCE, compatible_mapping["RMSNorm"]["rocm"])
+
+        # Verify LayerRepository objects are created correctly
+        cuda_repo = compatible_mapping["RMSNorm"]["cuda"][Mode.INFERENCE]
+        rocm_repo = compatible_mapping["RMSNorm"]["rocm"][Mode.INFERENCE]
+        self.assertEqual(cuda_repo._repo_id, "repo")
+        self.assertEqual(cuda_repo.layer_name, "layer")
+        self.assertEqual(rocm_repo._repo_id, "repo")
+        self.assertEqual(rocm_repo.layer_name, "layer")
+
+    def test_add_to_mapping_single_device(self):
+        """Test that add_to_mapping works correctly with a single device."""
+        compatible_mapping = {}
+
+        # Add single device
+        add_to_mapping("RMSNorm", "cuda", "repo:layer", Mode.INFERENCE, compatible_mapping)
+
+        # Verify structure
+        self.assertIn("RMSNorm", compatible_mapping)
+        self.assertIn("cuda", compatible_mapping["RMSNorm"])
+        self.assertIn(Mode.INFERENCE, compatible_mapping["RMSNorm"]["cuda"])
+
+        # Verify LayerRepository object
+        repo = compatible_mapping["RMSNorm"]["cuda"][Mode.INFERENCE]
+        self.assertEqual(repo._repo_id, "repo")
+        self.assertEqual(repo.layer_name, "layer")
+
+    def test_add_to_mapping_multiple_modes(self):
+        """Test that add_to_mapping can handle multiple modes for the same device."""  # noqa: E501
+        compatible_mapping = {}
+
+        # Add inference mode
+        add_to_mapping("RMSNorm", "cuda", "repo:layer", Mode.INFERENCE, compatible_mapping)
+
+        # Add training mode - should NOT overwrite inference
+        add_to_mapping("RMSNorm", "cuda", "repo:layer", Mode.TRAINING, compatible_mapping)
+
+        # Verify both modes exist
+        self.assertIn(Mode.INFERENCE, compatible_mapping["RMSNorm"]["cuda"])
+        self.assertIn(Mode.TRAINING, compatible_mapping["RMSNorm"]["cuda"])
+
 
 @require_kernels
 class TestAttentionKernelRegistration(TestCasePlus):
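For context, here is a minimal sketch of the behaviour the new tests pin down: with the rewritten add_to_mapping, repeated calls merge into the nested {layer_name: {device: {mode: LayerRepository}}} structure instead of replacing the whole layer entry. This assumes Mode and LayerRepository are importable from the kernels package, as in the test file; "repo:layer" is a placeholder repo name.

    from kernels import Mode
    from transformers.utils.kernel_config import add_to_mapping

    compatible_mapping = {}

    # Register two devices and two modes for the same layer;
    # earlier entries survive each subsequent call.
    add_to_mapping("RMSNorm", "cuda", "repo:layer", Mode.INFERENCE, compatible_mapping)
    add_to_mapping("RMSNorm", "rocm", "repo:layer", Mode.INFERENCE, compatible_mapping)
    add_to_mapping("RMSNorm", "cuda", "repo:layer", Mode.TRAINING, compatible_mapping)

    # compatible_mapping is now roughly:
    # {"RMSNorm": {"cuda": {Mode.INFERENCE: LayerRepository(...),
    #                       Mode.TRAINING: LayerRepository(...)},
    #              "rocm": {Mode.INFERENCE: LayerRepository(...)}}}

Before this change, each call rebound compatible_mapping[layer_name] wholesale, so after the third call only the cuda/TRAINING entry would have remained.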