Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions src/transformers/utils/kernel_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,20 @@ def add_to_mapping(layer_name, device, repo_name, mode, compatible_mapping):
raise ValueError(f"Only cuda, rocm, xpu and npu devices supported, got: {device}")
repo_layer_name = repo_name.split(":")[1]
repo_id = repo_name.split(":")[0]
compatible_mapping[layer_name] = {
device: {
mode: LayerRepository(
repo_id=repo_id,
layer_name=repo_layer_name,
)
}
}

# Initialise layer_name entry if it doesn't exist
if layer_name not in compatible_mapping:
compatible_mapping[layer_name] = {}

# Initialise device entry if it doesn't exist
if device not in compatible_mapping[layer_name]:
compatible_mapping[layer_name][device] = {}

# Add the mode entry (this can overwrite if mode already exists, which is fine)
compatible_mapping[layer_name][device][mode] = LayerRepository(
repo_id=repo_id,
layer_name=repo_layer_name,
)


class KernelConfig(PushToHubMixin):
Expand Down
58 changes: 58 additions & 0 deletions tests/kernels/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
torch_device,
)
from transformers.utils.import_utils import is_kernels_available
from transformers.utils.kernel_config import add_to_mapping


if is_kernels_available():
Expand Down Expand Up @@ -313,6 +314,63 @@ def fake_get_kernel(repo_id, revision=None, version=None, user_agent=None):
HUB[name] = original_entry
_KERNEL_MODULE_MAPPING.pop(name, None)

def test_add_to_mapping_multiple_devices(self):
    """Registering a second device for a layer must keep the first one.

    The same ``"repo:layer"`` kernel is registered for ``RMSNorm`` on
    ``cuda`` and then on ``rocm``; both device entries are expected to be
    present afterwards, each holding a populated repository object under
    ``Mode.INFERENCE``.
    """
    devices = ("cuda", "rocm")
    mapping = {}

    # Register the devices one after another; the later call must not
    # replace the entry created by the earlier one.
    for device in devices:
        add_to_mapping("RMSNorm", device, "repo:layer", Mode.INFERENCE, mapping)

    # Both device keys are present under the layer name.
    for device in devices:
        self.assertIn(device, mapping["RMSNorm"])

    # Each device carries an entry for the inference mode.
    for device in devices:
        self.assertIn(Mode.INFERENCE, mapping["RMSNorm"][device])

    # The stored repository objects reflect the two halves of "repo:layer".
    stored = [mapping["RMSNorm"][device][Mode.INFERENCE] for device in devices]
    for layer_repo in stored:
        self.assertEqual(layer_repo._repo_id, "repo")
        self.assertEqual(layer_repo.layer_name, "layer")

def test_add_to_mapping_single_device(self):
    """A single registration produces the layer -> device -> mode nesting."""
    mapping = {}
    add_to_mapping("RMSNorm", "cuda", "repo:layer", Mode.INFERENCE, mapping)

    # Walk down the nesting one level at a time: layer name, device, mode.
    self.assertIn("RMSNorm", mapping)
    per_device = mapping["RMSNorm"]
    self.assertIn("cuda", per_device)
    per_mode = per_device["cuda"]
    self.assertIn(Mode.INFERENCE, per_mode)

    # The leaf object carries the two halves of the "repo:layer" string.
    layer_repo = per_mode[Mode.INFERENCE]
    self.assertEqual(layer_repo._repo_id, "repo")
    self.assertEqual(layer_repo.layer_name, "layer")

def test_add_to_mapping_multiple_modes(self):
    """Two modes registered for one layer/device pair must both be kept."""
    modes = (Mode.INFERENCE, Mode.TRAINING)
    mapping = {}

    # Inference is registered first, then training; the second call must
    # not drop the entry created by the first.
    for mode in modes:
        add_to_mapping("RMSNorm", "cuda", "repo:layer", mode, mapping)

    registered_modes = mapping["RMSNorm"]["cuda"]
    for mode in modes:
        self.assertIn(mode, registered_modes)


@require_kernels
class TestAttentionKernelRegistration(TestCasePlus):
Expand Down