From 1ccf13b33da7da5419344d232aae9059b6d6bf9a Mon Sep 17 00:00:00 2001 From: Grancho Date: Fri, 28 Nov 2025 16:21:20 +0000 Subject: [PATCH 1/5] fix(kernel_config): prevent overwriting existing mappings in add_to_mapping Previously, the add_to_mapping function would overwrite existing layer_name and device entries in the compatible_mapping dictionary when adding new entries. This could cause loss of existing kernel mappings. The fix ensures that: - Layer name entries are preserved if they already exist - Device entries within layer names are preserved if they already exist - Only mode entries can be overwritten (which is the intended behaviour) This makes the function more robust when building compatible mappings incrementally, preventing accidental data loss. --- src/transformers/utils/kernel_config.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/transformers/utils/kernel_config.py b/src/transformers/utils/kernel_config.py index f6400053a434..aaffa5016afd 100644 --- a/src/transformers/utils/kernel_config.py +++ b/src/transformers/utils/kernel_config.py @@ -61,14 +61,20 @@ def add_to_mapping(layer_name, device, repo_name, mode, compatible_mapping): raise ValueError(f"Only cuda, rocm, xpu and npu devices supported, got: {device}") repo_layer_name = repo_name.split(":")[1] repo_id = repo_name.split(":")[0] - compatible_mapping[layer_name] = { - device: { - mode: LayerRepository( - repo_id=repo_id, - layer_name=repo_layer_name, - ) - } - } + + # Initialise layer_name entry if it doesn't exist + if layer_name not in compatible_mapping: + compatible_mapping[layer_name] = {} + + # Initialise device entry if it doesn't exist + if device not in compatible_mapping[layer_name]: + compatible_mapping[layer_name][device] = {} + + # Add the mode entry (this can overwrite if mode already exists, which is fine) + compatible_mapping[layer_name][device][mode] = LayerRepository( + repo_id=repo_id, + layer_name=repo_layer_name, + ) class KernelConfig(PushToHubMixin): From c5fb454da80fcc307d71da096d3ee8aa2d2977ad Mon Sep 17 00:00:00 2001 From: Grancho Date: Fri, 28 Nov 2025 16:23:32 +0000 Subject: [PATCH 2/5] test(kernels): add tests for add_to_mapping function Add comprehensive test coverage for the add_to_mapping function to ensure it correctly handles multiple devices and modes without overwriting existing entries in the compatible_mapping dictionary. The new tests verify: - Multiple devices can be added for the same layer_name without overwriting each other (test_add_to_mapping_multiple_devices) - Single device mappings are created correctly (test_add_to_mapping_single_device) - Multiple modes can be added for the same device without overwriting each other (test_add_to_mapping_multiple_modes) These tests complement the fix in kernel_config.py that prevents overwriting existing mappings when building compatible mappings incrementally. --- tests/kernels/test_kernels.py | 58 +++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/tests/kernels/test_kernels.py b/tests/kernels/test_kernels.py index 6311629ac4f2..5b2c18411794 100644 --- a/tests/kernels/test_kernels.py +++ b/tests/kernels/test_kernels.py @@ -27,6 +27,7 @@ lazy_load_kernel, load_and_register_attn_kernel, ) +from transformers.utils.kernel_config import add_to_mapping from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.testing_utils import ( @@ -313,6 +314,63 @@ def fake_get_kernel(repo_id, revision=None, version=None, user_agent=None): HUB[name] = original_entry _KERNEL_MODULE_MAPPING.pop(name, None) + def test_add_to_mapping_multiple_devices(self): + """Test that add_to_mapping preserves multiple devices for the same layer_name.""" + compatible_mapping = {} + + # Add cuda device + add_to_mapping("RMSNorm", "cuda", "repo:layer", Mode.INFERENCE, compatible_mapping) + + # Add rocm device - should NOT overwrite cuda + add_to_mapping("RMSNorm", "rocm", "repo:layer", Mode.INFERENCE, compatible_mapping) + + # Verify both devices exist + self.assertIn("cuda", compatible_mapping["RMSNorm"]) + self.assertIn("rocm", compatible_mapping["RMSNorm"]) + + # Verify the structure is correct + self.assertIn(Mode.INFERENCE, compatible_mapping["RMSNorm"]["cuda"]) + self.assertIn(Mode.INFERENCE, compatible_mapping["RMSNorm"]["rocm"]) + + # Verify LayerRepository objects are created correctly + cuda_repo = compatible_mapping["RMSNorm"]["cuda"][Mode.INFERENCE] + rocm_repo = compatible_mapping["RMSNorm"]["rocm"][Mode.INFERENCE] + self.assertEqual(cuda_repo.repo_id, "repo") + self.assertEqual(cuda_repo.layer_name, "layer") + self.assertEqual(rocm_repo.repo_id, "repo") + self.assertEqual(rocm_repo.layer_name, "layer") + + def test_add_to_mapping_single_device(self): + """Test that add_to_mapping works correctly with a single device.""" + compatible_mapping = {} + + # Add single device + add_to_mapping("RMSNorm", "cuda", "repo:layer", Mode.INFERENCE, compatible_mapping) + + # Verify structure + self.assertIn("RMSNorm", compatible_mapping) + self.assertIn("cuda", compatible_mapping["RMSNorm"]) + self.assertIn(Mode.INFERENCE, compatible_mapping["RMSNorm"]["cuda"]) + + # Verify LayerRepository object + repo = compatible_mapping["RMSNorm"]["cuda"][Mode.INFERENCE] + self.assertEqual(repo.repo_id, "repo") + self.assertEqual(repo.layer_name, "layer") + + def test_add_to_mapping_multiple_modes(self): + """Test that add_to_mapping can handle multiple modes for the same device.""" + compatible_mapping = {} + + # Add inference mode + add_to_mapping("RMSNorm", "cuda", "repo:layer", Mode.INFERENCE, compatible_mapping) + + # Add training mode - should NOT overwrite inference + add_to_mapping("RMSNorm", "cuda", "repo:layer", Mode.TRAINING, compatible_mapping) + + # Verify both modes exist + self.assertIn(Mode.INFERENCE, compatible_mapping["RMSNorm"]["cuda"]) + self.assertIn(Mode.TRAINING, compatible_mapping["RMSNorm"]["cuda"]) + @require_kernels class TestAttentionKernelRegistration(TestCasePlus): From a9ba94aa94079ac72947374c384c00a7823b6943 Mon Sep 17 00:00:00 2001 From: Grancho Date: Fri, 28 Nov 2025 17:01:46 +0000 Subject: [PATCH 3/5] style(test_kernels): format code and fix test assertions Apply code formatting improvements to the add_to_mapping test functions: - Break long function calls across multiple lines to comply with line length guidelines - Add noqa comments to suppress line length warnings for docstrings - Fix test assertions to use the correct private attribute _repo_id instead of the non-existent public repo_id attribute This ensures the tests follow the project's code style guidelines and correctly verify the LayerRepository object properties. --- tests/kernels/test_kernels.py | 52 +++++++++++++++++++++-------------- 1 file changed, 31 insertions(+), 21 deletions(-) diff --git a/tests/kernels/test_kernels.py b/tests/kernels/test_kernels.py index 5b2c18411794..25ca0edf689b 100644 --- a/tests/kernels/test_kernels.py +++ b/tests/kernels/test_kernels.py @@ -315,58 +315,68 @@ def fake_get_kernel(repo_id, revision=None, version=None, user_agent=None): _KERNEL_MODULE_MAPPING.pop(name, None) def test_add_to_mapping_multiple_devices(self): - """Test that add_to_mapping preserves multiple devices for the same layer_name.""" + """Test that add_to_mapping preserves multiple devices for the same layer_name.""" # noqa: E501 compatible_mapping = {} - + # Add cuda device - add_to_mapping("RMSNorm", "cuda", "repo:layer", Mode.INFERENCE, compatible_mapping) - + add_to_mapping( + "RMSNorm", "cuda", "repo:layer", Mode.INFERENCE, compatible_mapping + ) + # Add rocm device - should NOT overwrite cuda - add_to_mapping("RMSNorm", "rocm", "repo:layer", Mode.INFERENCE, compatible_mapping) - + add_to_mapping( + "RMSNorm", "rocm", "repo:layer", Mode.INFERENCE, compatible_mapping + ) + # Verify both devices exist self.assertIn("cuda", compatible_mapping["RMSNorm"]) self.assertIn("rocm", compatible_mapping["RMSNorm"]) - + # Verify the structure is correct self.assertIn(Mode.INFERENCE, compatible_mapping["RMSNorm"]["cuda"]) self.assertIn(Mode.INFERENCE, compatible_mapping["RMSNorm"]["rocm"]) - + # Verify LayerRepository objects are created correctly cuda_repo = compatible_mapping["RMSNorm"]["cuda"][Mode.INFERENCE] rocm_repo = compatible_mapping["RMSNorm"]["rocm"][Mode.INFERENCE] - self.assertEqual(cuda_repo.repo_id, "repo") + self.assertEqual(cuda_repo._repo_id, "repo") self.assertEqual(cuda_repo.layer_name, "layer") - self.assertEqual(rocm_repo.repo_id, "repo") + self.assertEqual(rocm_repo._repo_id, "repo") self.assertEqual(rocm_repo.layer_name, "layer") def test_add_to_mapping_single_device(self): """Test that add_to_mapping works correctly with a single device.""" compatible_mapping = {} - + # Add single device - add_to_mapping("RMSNorm", "cuda", "repo:layer", Mode.INFERENCE, compatible_mapping) - + add_to_mapping( + "RMSNorm", "cuda", "repo:layer", Mode.INFERENCE, compatible_mapping + ) + # Verify structure self.assertIn("RMSNorm", compatible_mapping) self.assertIn("cuda", compatible_mapping["RMSNorm"]) self.assertIn(Mode.INFERENCE, compatible_mapping["RMSNorm"]["cuda"]) - + # Verify LayerRepository object repo = compatible_mapping["RMSNorm"]["cuda"][Mode.INFERENCE] - self.assertEqual(repo.repo_id, "repo") + self.assertEqual(repo._repo_id, "repo") self.assertEqual(repo.layer_name, "layer") def test_add_to_mapping_multiple_modes(self): - """Test that add_to_mapping can handle multiple modes for the same device.""" + """Test that add_to_mapping can handle multiple modes for the same device.""" # noqa: E501 compatible_mapping = {} - + # Add inference mode - add_to_mapping("RMSNorm", "cuda", "repo:layer", Mode.INFERENCE, compatible_mapping) - + add_to_mapping( + "RMSNorm", "cuda", "repo:layer", Mode.INFERENCE, compatible_mapping + ) + # Add training mode - should NOT overwrite inference - add_to_mapping("RMSNorm", "cuda", "repo:layer", Mode.TRAINING, compatible_mapping) - + add_to_mapping( + "RMSNorm", "cuda", "repo:layer", Mode.TRAINING, compatible_mapping + ) + # Verify both modes exist self.assertIn(Mode.INFERENCE, compatible_mapping["RMSNorm"]["cuda"]) self.assertIn(Mode.TRAINING, compatible_mapping["RMSNorm"]["cuda"]) From e0986d01ca7bcc4af97520f3e4f8b8fcb4c3a3b2 Mon Sep 17 00:00:00 2001 From: Grancho Date: Fri, 28 Nov 2025 17:44:19 +0000 Subject: [PATCH 4/5] style(test_kernels): reorganise imports and consolidate function calls Reorganise imports to group transformers.utils imports together by moving add_to_mapping import to be with other utils imports rather than between hub_kernels and masking_utils imports. Consolidate add_to_mapping function calls back to single lines as they fit within the line length limits, improving code readability. --- tests/kernels/test_kernels.py | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/tests/kernels/test_kernels.py b/tests/kernels/test_kernels.py index 25ca0edf689b..606d3a0b89cd 100644 --- a/tests/kernels/test_kernels.py +++ b/tests/kernels/test_kernels.py @@ -27,7 +27,6 @@ lazy_load_kernel, load_and_register_attn_kernel, ) -from transformers.utils.kernel_config import add_to_mapping from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.testing_utils import ( @@ -39,6 +38,7 @@ torch_device, ) from transformers.utils.import_utils import is_kernels_available +from transformers.utils.kernel_config import add_to_mapping if is_kernels_available(): @@ -319,14 +319,10 @@ def test_add_to_mapping_multiple_devices(self): compatible_mapping = {} # Add cuda device - add_to_mapping( - "RMSNorm", "cuda", "repo:layer", Mode.INFERENCE, compatible_mapping - ) + add_to_mapping("RMSNorm", "cuda", "repo:layer", Mode.INFERENCE, compatible_mapping) # Add rocm device - should NOT overwrite cuda - add_to_mapping( - "RMSNorm", "rocm", "repo:layer", Mode.INFERENCE, compatible_mapping - ) + add_to_mapping("RMSNorm", "rocm", "repo:layer", Mode.INFERENCE, compatible_mapping) # Verify both devices exist self.assertIn("cuda", compatible_mapping["RMSNorm"]) @@ -349,9 +345,7 @@ def test_add_to_mapping_single_device(self): compatible_mapping = {} # Add single device - add_to_mapping( - "RMSNorm", "cuda", "repo:layer", Mode.INFERENCE, compatible_mapping - ) + add_to_mapping("RMSNorm", "cuda", "repo:layer", Mode.INFERENCE, compatible_mapping) # Verify structure self.assertIn("RMSNorm", compatible_mapping) @@ -368,14 +362,10 @@ def test_add_to_mapping_multiple_modes(self): compatible_mapping = {} # Add inference mode - add_to_mapping( - "RMSNorm", "cuda", "repo:layer", Mode.INFERENCE, compatible_mapping - ) + add_to_mapping("RMSNorm", "cuda", "repo:layer", Mode.INFERENCE, compatible_mapping) # Add training mode - should NOT overwrite inference - add_to_mapping( - "RMSNorm", "cuda", "repo:layer", Mode.TRAINING, compatible_mapping - ) + add_to_mapping("RMSNorm", "cuda", "repo:layer", Mode.TRAINING, compatible_mapping) # Verify both modes exist self.assertIn(Mode.INFERENCE, compatible_mapping["RMSNorm"]["cuda"]) From 9548b4c0fd6f2873ea43610d704bb1a8bff6ccb7 Mon Sep 17 00:00:00 2001 From: Grancho Date: Fri, 28 Nov 2025 17:45:04 +0000 Subject: [PATCH 5/5] style(kernel_config): remove trailing whitespace from blank lines Remove trailing whitespace from blank lines in the add_to_mapping function to comply with code style guidelines and ensure consistent formatting. --- src/transformers/utils/kernel_config.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/utils/kernel_config.py b/src/transformers/utils/kernel_config.py index aaffa5016afd..4dfc81577e44 100644 --- a/src/transformers/utils/kernel_config.py +++ b/src/transformers/utils/kernel_config.py @@ -61,15 +61,15 @@ def add_to_mapping(layer_name, device, repo_name, mode, compatible_mapping): raise ValueError(f"Only cuda, rocm, xpu and npu devices supported, got: {device}") repo_layer_name = repo_name.split(":")[1] repo_id = repo_name.split(":")[0] - + # Initialise layer_name entry if it doesn't exist if layer_name not in compatible_mapping: compatible_mapping[layer_name] = {} - + # Initialise device entry if it doesn't exist if device not in compatible_mapping[layer_name]: compatible_mapping[layer_name][device] = {} - + # Add the mode entry (this can overwrite if mode already exists, which is fine) compatible_mapping[layer_name][device][mode] = LayerRepository( repo_id=repo_id,