test/prototype/test_parq.py (6 additions, 1 deletion)
@@ -51,7 +51,12 @@
     torch_version_at_least,
 )

-_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+if torch.cuda.is_available():
+    _DEVICE = "cuda"
+elif torch.xpu.is_available():
+    _DEVICE = "xpu"
+else:
+    _DEVICE = "cpu"


 class M(nn.Module):
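Note that the new `_DEVICE` in test_parq.py is a plain string where the old code built a `torch.device`; both forms are accepted anywhere a `device=` argument is taken, so the rest of the test module needs no changes. A minimal sanity check of that equivalence (illustrative only, not part of this PR):

import torch

# Same selection logic as the hunk above (module-level constant assumed).
if torch.cuda.is_available():
    _DEVICE = "cuda"
elif torch.xpu.is_available():
    _DEVICE = "xpu"
else:
    _DEVICE = "cpu"

# A string and a torch.device resolve to the same device type for tensor factories.
x = torch.zeros(2, 2, device=_DEVICE)
assert x.device.type == torch.device(_DEVICE).type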
test/prototype/test_quantized_training.py (19 additions, 9 deletions)
@@ -34,11 +34,13 @@
     quantize_int8_rowwise,
 )
 from torchao.quantization.quant_api import quantize_
+from torchao.utils import get_current_accelerator_device

 if common_utils.SEED is None:
     common_utils.SEED = 1234

-_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
+_DEVICE = get_current_accelerator_device()
+_DEVICES = ["cpu"] + ([_DEVICE] if torch.accelerator.is_available() else [])


 def _reset():
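`get_current_accelerator_device` is imported from `torchao.utils`, but its implementation is not shown in this diff. A plausible sketch, assuming it simply mirrors the per-backend selection added to test_parq.py above (not the actual torchao implementation):

import torch

def get_current_accelerator_device() -> str:
    # Assumed behavior: name of the first available accelerator backend,
    # falling back to "cpu"; the real helper may return a torch.device instead.
    if torch.cuda.is_available():
        return "cuda"
    if torch.xpu.is_available():
        return "xpu"
    return "cpu"

Either a string or a torch.device works in the `_DEVICES` list above, since both are valid values for a `device=` keyword.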
@@ -182,12 +184,14 @@ def test_int8_weight_only_training(self, compile, device):
         ],
     )
     @parametrize("module_swap", [False, True])
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(
+        not torch.accelerator.is_available(), reason="GPU not available"
+    )
     def test_int8_mixed_precision_training(self, compile, config, module_swap):
         _reset()
         bsize = 64
         embed_dim = 64
-        device = "cuda"
+        device = _DEVICE

         linear = nn.Linear(embed_dim, embed_dim, device=device)
         linear_int8mp = copy.deepcopy(linear)
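The same accelerator guard now appears on several tests; it could be hoisted into one shared marker. A small sketch of such a refactor (hypothetical, not part of this PR):

import pytest
import torch

# Reusable marker; tests would then carry @requires_accelerator instead of
# repeating the skipif condition.
requires_accelerator = pytest.mark.skipif(
    not torch.accelerator.is_available(), reason="GPU not available"
)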
@@ -221,7 +225,9 @@ def snr(ref, actual):

     @pytest.mark.skip("Flaky on CI")
     @parametrize("compile", [False, True])
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(
+        not torch.accelerator.is_available(), reason="GPU not available"
+    )
     def test_bitnet_training(self, compile):
         # reference implementation
         # https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf
@@ -246,7 +252,7 @@ def forward(self, x):
         _reset()
         bsize = 4
         embed_dim = 32
-        device = "cuda"
+        device = _DEVICE

         # only use 1 matmul shape to reduce triton autotune time
         model_ref = nn.Sequential(
@@ -342,7 +348,7 @@ def _run_subtest(self, args):
             dropout_p=0,
         )
         torch.manual_seed(42)
-        base_model = Transformer(model_args).cuda()
+        base_model = Transformer(model_args).to(_DEVICE)
         fsdp_model = copy.deepcopy(base_model)

         quantize_(base_model.layers, quantize_fn)
@@ -362,7 +368,7 @@

         torch.manual_seed(42 + self.rank + 1)
         for iter_idx in range(5):
-            inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
+            inp = torch.randint(0, vocab_size, (batch_size, seq_len), device=_DEVICE)
             fsdp_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
             fsdp_loss = fsdp_model(inp).sum()
             fsdp_loss.backward()
@@ -387,14 +393,18 @@ def _run_subtest(self, args):
         )

     @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(
+        not torch.accelerator.is_available(), reason="GPU not available"
+    )
     def test_precompute_bitnet_scale(self):
         from torchao.prototype.quantized_training.bitnet import (
             get_bitnet_scale,
             precompute_bitnet_scale_for_fsdp,
         )

-        model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32)).cuda()
+        model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32)).to(
+            _DEVICE
+        )
         model_fsdp = copy.deepcopy(model)
         quantize_(model_fsdp, bitnet_training())
         fully_shard(model_fsdp)