Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/transformers/utils/kernel_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def create_compatible_mapping(self, model, compile=False):
from kernels import Mode

compatible_mapping = {}
current_device = infer_device(model)
for layer_name, kernel in self.kernel_mapping.items():
# Infer Mode: use Mode.TRAINING if model is training, else use Mode.INFERENCE
mode = Mode.TRAINING if model.training else Mode.INFERENCE
Expand All @@ -216,10 +217,11 @@ def create_compatible_mapping(self, model, compile=False):

if isinstance(kernel, str):
repo_name = kernel
device = infer_device(model)
add_to_mapping(layer_name, device, repo_name, mode, compatible_mapping)
add_to_mapping(layer_name, current_device, repo_name, mode, compatible_mapping)
elif isinstance(kernel, dict):
for device, repo_name in kernel.items():
if device != current_device:
continue
add_to_mapping(layer_name, device, repo_name, mode, compatible_mapping)

self.kernel_mapping = compatible_mapping
115 changes: 114 additions & 1 deletion tests/kernels/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import copy
import types
from unittest.mock import patch
from unittest.mock import MagicMock, patch

from transformers import AutoModelForCausalLM, AutoTokenizer, KernelConfig
from transformers.integrations.hub_kernels import (
Expand Down Expand Up @@ -401,3 +401,116 @@ def spy_kernelize(model, device=None, mode=None):
self.assertTrue(any(m == Mode.TRAINING for m in last_modes))
self.model.eval()
self.assertTrue(any(m == Mode.INFERENCE for m in last_modes))


@require_kernels
class TestKernelMappingDeviceFiltering(TestCasePlus):
    """Test that kernel mappings correctly filter by current device.

    The models used here are `MagicMock` stand-ins, so no real weights or
    devices are needed to build them.
    """

    @staticmethod
    def _multi_device_mapping():
        """Return a fresh mapping that registers the same RMSNorm kernel for cuda and rocm."""
        return {
            "RMSNorm": {
                "cuda": "kernels-community/layer_norm:LlamaRMSNorm",
                "rocm": "kernels-community/layer_norm:LlamaRMSNorm",
            }
        }

    @staticmethod
    def _make_mock_model(device_type):
        """
        Build a mock model in eval mode with one parameter and one RMSNorm layer.

        Args:
            device_type: value exposed as `param.device.type` on the single mocked
                parameter (presumably what device inference reads; confirm against
                `infer_device`).

        Returns:
            A `MagicMock` exposing `training`, `parameters()` and `named_modules()`.
        """
        mock_param = MagicMock()
        mock_param.device.type = device_type

        mock_layer = MagicMock()
        mock_layer.kernel_layer_name = "RMSNorm"

        mock_model = MagicMock()
        mock_model.training = False
        # Use `side_effect` so every `parameters()` call gets a fresh iterator. A single
        # `iter(...)` stored in `return_value` is exhausted after the first `next()`, which
        # would make the tests depend on how often the code under test calls `parameters()`.
        mock_model.parameters.side_effect = lambda: iter([mock_param])
        mock_model.named_modules.return_value = [("layers.0", mock_layer)]
        return mock_model

    def test_multi_device_mapping_filters_correctly(self):
        """
        Test that when a kernel_mapping contains multiple devices (cuda, rocm),
        only the current device's kernel is registered.
        Regression test for issue where ROCm overwrote CUDA mapping.
        """
        kernel_config = KernelConfig(self._multi_device_mapping())

        # Trigger the mapping creation with a mock model on a CUDA device
        kernel_config.create_compatible_mapping(self._make_mock_model("cuda"))

        # Verify results
        result_mapping = kernel_config.kernel_mapping

        self.assertIn("RMSNorm", result_mapping, "RMSNorm should be in mapping")
        backends = list(result_mapping["RMSNorm"].keys())

        # Assert only CUDA is present, not ROCm
        self.assertIn("cuda", backends, "CUDA backend should be registered")
        self.assertNotIn("rocm", backends, "ROCm backend should NOT be registered on CUDA device")

    def test_rocm_device_filters_correctly(self):
        """
        Test that ROCm device correctly filters out CUDA kernels.
        """
        kernel_config = KernelConfig(self._multi_device_mapping())

        # Trigger the mapping creation with a mock model on a ROCm device
        kernel_config.create_compatible_mapping(self._make_mock_model("rocm"))

        # Verify results
        result_mapping = kernel_config.kernel_mapping

        self.assertIn("RMSNorm", result_mapping, "RMSNorm should be in mapping")
        backends = list(result_mapping["RMSNorm"].keys())

        # Assert only ROCm is present, not CUDA
        self.assertIn("rocm", backends, "ROCm backend should be registered")
        self.assertNotIn("cuda", backends, "CUDA backend should NOT be registered on ROCm device")

    def test_single_device_mapping_still_works(self):
        """
        Test that single-device mappings continue to work as expected.
        """
        kernel_mapping = {"RMSNorm": "kernels-community/layer_norm:LlamaRMSNorm"}

        kernel_config = KernelConfig(kernel_mapping)

        # Should not raise any errors
        kernel_config.create_compatible_mapping(self._make_mock_model("cuda"))

        result_mapping = kernel_config.kernel_mapping
        self.assertIn("RMSNorm", result_mapping, "RMSNorm should be in mapping")