Commit c2ecff1

test/prototype/mx_formats/test_mx_tensor.py
1 parent 69ad8ff commit c2ecff1

1 file changed (+4, -3 lines)

torchao/prototype/mx_formats/kernels.py

Lines changed: 4 additions & 3 deletions
@@ -38,7 +38,6 @@
 
 logger = logging.getLogger(__name__)
 
-
 def get_bits(x: torch.Tensor) -> str:
     bits_per_byte = 8
     # Numpy has a nice function to get the string representation of binary.
@@ -628,7 +627,8 @@ def triton_f6_e2m3_to_scaled_bf16(
     output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16)
 
     assert x.is_contiguous()
-    assert x.is_cuda and output.is_cuda
+    assert x.is_cuda or x.device.type == "xpu"
+    assert output.is_cuda or output.device.type == "xpu"
 
     n_mx_blocks = x.shape[0]
     grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),)
@@ -671,7 +671,8 @@ def triton_f6_e3m2_to_scaled_bf16(
     output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16)
 
     assert x.is_contiguous()
-    assert x.is_cuda and output.is_cuda
+    assert x.is_cuda or x.device.type == "xpu"
+    assert output.is_cuda or output.device.type == "xpu"
 
     n_mx_blocks = x.numel() // packed_mx_block_size
     grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),)
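
The two edited kernels previously required CUDA tensors; the commit widens the device check so tensors on Intel XPU devices are accepted as well. Below is a minimal, self-contained sketch of that check, factored into a helper for illustration; _assert_cuda_or_xpu, the example shapes, and the device-selection logic are hypothetical and are not part of torchao/prototype/mx_formats/kernels.py.

import torch


def _assert_cuda_or_xpu(t: torch.Tensor, name: str) -> None:
    # Illustrative helper mirroring the commit's assertions: the tensor must
    # live on a CUDA device or an Intel XPU device before the Triton kernel
    # is launched.
    assert t.is_cuda or t.device.type == "xpu", (
        f"{name} must be on a CUDA or XPU device, got {t.device}"
    )


# Usage sketch, assuming at least one supported accelerator is present.
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    device = "xpu"
else:
    raise RuntimeError("no CUDA or XPU device available")

x = torch.randn(8, 8, device=device)
output = torch.empty(8, 8, device=device, dtype=torch.bfloat16)
_assert_cuda_or_xpu(x, "x")
_assert_cuda_or_xpu(output, "output")

Checking device.type == "xpu" as an alternative to is_cuda keeps the existing CUDA path unchanged while letting the same Triton kernels run through the XPU backend.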

0 commit comments
