File tree Expand file tree Collapse file tree 1 file changed +4
-3
lines changed
torchao/prototype/mx_formats Expand file tree Collapse file tree 1 file changed +4
-3
lines changed Original file line number Diff line number Diff line change 3838
3939logger = logging .getLogger (__name__ )
4040
41-
4241def get_bits (x : torch .Tensor ) -> str :
4342 bits_per_byte = 8
4443 # Numpy has a nice function to get the string representation of binary.
@@ -628,7 +627,8 @@ def triton_f6_e2m3_to_scaled_bf16(
628627 output = torch .empty (* new_shape , device = x .device , dtype = torch .bfloat16 )
629628
630629 assert x .is_contiguous ()
631- assert x .is_cuda and output .is_cuda
630+ assert x .is_cuda or x .device .type == "xpu"
631+ assert output .is_cuda or output .device .type == "xpu"
632632
633633 n_mx_blocks = x .shape [0 ]
634634 grid = lambda meta : (triton .cdiv (n_mx_blocks , meta ["BLOCK_SIZE_IN" ]),)
@@ -671,7 +671,8 @@ def triton_f6_e3m2_to_scaled_bf16(
671671 output = torch .empty (* new_shape , device = x .device , dtype = torch .bfloat16 )
672672
673673 assert x .is_contiguous ()
674- assert x .is_cuda and output .is_cuda
674+ assert x .is_cuda or x .device .type == "xpu"
675+ assert output .is_cuda or output .device .type == "xpu"
675676
676677 n_mx_blocks = x .numel () // packed_mx_block_size
677678 grid = lambda meta : (triton .cdiv (n_mx_blocks , meta ["BLOCK_SIZE_IN" ]),)
You can’t perform that action at this time.
0 commit comments