9 changes: 6 additions & 3 deletions test/float8/test_float8_utils.py
@@ -10,11 +10,14 @@
 
 from torchao.float8.float8_utils import _round_scale_down_to_power_of_2
 from torchao.testing.utils import skip_if_rocm
+from torchao.utils import get_current_accelerator_device
+
+_DEVICE = get_current_accelerator_device()
 
 
 # source for notable single-precision cases:
 # https://en.wikipedia.org/wiki/Single-precision_floating-point_format
-@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+@unittest.skipIf(not torch.accelerator.is_available(), "GPU not available")
 @pytest.mark.parametrize(
     "test_case",
     [
@@ -38,8 +41,8 @@ def test_round_scale_down_to_power_of_2_valid_inputs(
 ):
     test_case_name, input, expected_result = test_case
     input_tensor, expected_tensor = (
-        torch.tensor(input, dtype=torch.float32).cuda(),
-        torch.tensor(expected_result, dtype=torch.float32).cuda(),
+        torch.tensor(input, dtype=torch.float32).to(_DEVICE),
+        torch.tensor(expected_result, dtype=torch.float32).to(_DEVICE),
     )
     result = _round_scale_down_to_power_of_2(input_tensor)
 
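For context, the change above relies on torchao.utils.get_current_accelerator_device, whose implementation is not shown in this diff. A minimal sketch of such a helper, assuming it simply wraps the torch.accelerator API available in recent PyTorch releases (the real torchao helper may differ):

import torch

def get_current_accelerator_device() -> torch.device:
    # Hypothetical sketch: return the active accelerator (CUDA, XPU, MPS, ...)
    # when one exists, otherwise fall back to CPU.
    if torch.accelerator.is_available():
        # current_accelerator() reports the active accelerator as a torch.device,
        # e.g. device("cuda") or device("xpu").
        return torch.accelerator.current_accelerator()
    return torch.device("cpu")

# Usage mirroring the test change: allocate tensors on whatever device is available.
_DEVICE = get_current_accelerator_device()
x = torch.tensor([1.0, 2.0], dtype=torch.float32).to(_DEVICE)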
9 changes: 6 additions & 3 deletions test/hqq/test_hqq_affine.py
@@ -14,12 +14,14 @@
     ZeroPointDomain,
     quantize_,
 )
-from torchao.testing.utils import skip_if_rocm
+from torchao.testing.utils import skip_if_rocm, skip_if_xpu
+from torchao.utils import get_current_accelerator_device
 
-cuda_available = torch.cuda.is_available()
+cuda_available = torch.accelerator.is_available()
+_DEVICE = get_current_accelerator_device()
 
 # Parameters
-device = "cuda:0"
+device = f"{_DEVICE}:0"
 compute_dtype = torch.bfloat16
 group_size = 64
 mapping_type = MappingType.ASYMMETRIC
@@ -114,6 +116,7 @@ def test_hqq_plain_5bit(self):
         )
 
     @skip_if_rocm("ROCm enablement in progress")
+    @skip_if_xpu("XPU enablement in progress")
     def test_hqq_plain_4bit(self):
         self._test_hqq(
             dtype=torch.uint4,
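The new skip_if_xpu decorator is imported from torchao.testing.utils and is not defined in this diff. A plausible sketch, assuming it mirrors the existing skip_if_rocm pattern of skipping the wrapped test with a reason when the backend is active (the actual torchao implementation may differ):

import functools
import unittest

import torch

def skip_if_xpu(reason: str):
    # Hypothetical decorator: skip the wrapped test when an XPU device is present.
    def decorator(test_func):
        @functools.wraps(test_func)
        def wrapper(*args, **kwargs):
            if torch.xpu.is_available():
                raise unittest.SkipTest(f"Skipping on XPU: {reason}")
            return test_func(*args, **kwargs)
        return wrapper
    return decorator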