Skip to content

Commit 33d0629

Browse files
committed
add test/float8/test_float8_utils.py
1 parent 1272f3c commit 33d0629

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

test/float8/test_float8_utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@
1010

1111
from torchao.float8.float8_utils import _round_scale_down_to_power_of_2
1212
from torchao.testing.utils import skip_if_rocm
13+
from torchao.utils import get_current_accelerator_device
14+
15+
_DEVICE = get_current_accelerator_device()
1316

1417

1518
# source for notable single-precision cases:
1619
# https://en.wikipedia.org/wiki/Single-precision_floating-point_format
17-
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
20+
@unittest.skipIf(not torch.accelerator.is_available(), "GPU not available")
1821
@pytest.mark.parametrize(
1922
"test_case",
2023
[
@@ -38,8 +41,8 @@ def test_round_scale_down_to_power_of_2_valid_inputs(
3841
):
3942
test_case_name, input, expected_result = test_case
4043
input_tensor, expected_tensor = (
41-
torch.tensor(input, dtype=torch.float32).cuda(),
42-
torch.tensor(expected_result, dtype=torch.float32).cuda(),
44+
torch.tensor(input, dtype=torch.float32).to(_DEVICE),
45+
torch.tensor(expected_result, dtype=torch.float32).to(_DEVICE),
4346
)
4447
result = _round_scale_down_to_power_of_2(input_tensor)
4548

0 commit comments

Comments
 (0)