Commit 40c4f44

[xpu][test] Port 2 test/prototype/test_{parq, quantized_training} UT files to intel XPU (#3411)
* add test/prototype/test_parq.py
* add test/prototype/test_quantized_training.py
1 parent a6dbf45 commit 40c4f44

File tree

2 files changed: +28 -10 lines changed

test/prototype/test_parq.py

Lines changed: 5 additions & 1 deletion

@@ -51,7 +51,11 @@
     torch_version_at_least,
 )
 
-_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+_DEVICE = torch.device(
+    torch.accelerator.current_accelerator().type
+    if torch.accelerator.is_available()
+    else "cpu"
+)
 
 
 class M(nn.Module):
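For context, the new selection logic picks whichever accelerator backend is present (CUDA, XPU, ...) and falls back to CPU otherwise. A minimal standalone sketch of the same pattern, not part of the diff and assuming PyTorch 2.6+ where torch.accelerator is available:

import torch

# Pick the current accelerator type (e.g. "cuda" or "xpu") if one exists,
# otherwise fall back to CPU.
_DEVICE = torch.device(
    torch.accelerator.current_accelerator().type
    if torch.accelerator.is_available()
    else "cpu"
)

x = torch.randn(4, 4, device=_DEVICE)  # tensors land on the selected device
print(_DEVICE, x.device)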

test/prototype/test_quantized_training.py

Lines changed: 23 additions & 9 deletions

@@ -34,11 +34,17 @@
     quantize_int8_rowwise,
 )
 from torchao.quantization.quant_api import quantize_
+from torchao.utils import get_current_accelerator_device
 
 if common_utils.SEED is None:
     common_utils.SEED = 1234
 
-_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
+_DEVICES = (
+    ["cpu"]
+    + (["cuda"] if torch.cuda.is_available() else [])
+    + (["xpu"] if torch.xpu.is_available() else [])
+)
+_DEVICE = get_current_accelerator_device()
 
 
 def _reset():
@@ -182,12 +188,14 @@ def test_int8_weight_only_training(self, compile, device):
         ],
     )
     @parametrize("module_swap", [False, True])
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(
+        not torch.accelerator.is_available(), reason="GPU not available"
+    )
     def test_int8_mixed_precision_training(self, compile, config, module_swap):
         _reset()
         bsize = 64
         embed_dim = 64
-        device = "cuda"
+        device = _DEVICE
 
         linear = nn.Linear(embed_dim, embed_dim, device=device)
         linear_int8mp = copy.deepcopy(linear)
@@ -221,7 +229,9 @@ def snr(ref, actual):
 
     @pytest.mark.skip("Flaky on CI")
     @parametrize("compile", [False, True])
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(
+        not torch.accelerator.is_available(), reason="GPU not available"
+    )
     def test_bitnet_training(self, compile):
         # reference implementation
         # https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf
@@ -246,7 +256,7 @@ def forward(self, x):
         _reset()
         bsize = 4
         embed_dim = 32
-        device = "cuda"
+        device = _DEVICE
 
         # only use 1 matmul shape to reduce triton autotune time
         model_ref = nn.Sequential(
@@ -342,7 +352,7 @@ def _run_subtest(self, args):
             dropout_p=0,
         )
         torch.manual_seed(42)
-        base_model = Transformer(model_args).cuda()
+        base_model = Transformer(model_args).to(_DEVICE)
         fsdp_model = copy.deepcopy(base_model)
 
         quantize_(base_model.layers, quantize_fn)
@@ -362,7 +372,7 @@ def _run_subtest(self, args):
 
         torch.manual_seed(42 + self.rank + 1)
         for iter_idx in range(5):
-            inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
+            inp = torch.randint(0, vocab_size, (batch_size, seq_len), device=_DEVICE)
             fsdp_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
             fsdp_loss = fsdp_model(inp).sum()
             fsdp_loss.backward()
@@ -387,14 +397,18 @@ def _run_subtest(self, args):
         )
 
     @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(
+        not torch.accelerator.is_available(), reason="GPU not available"
+    )
     def test_precompute_bitnet_scale(self):
         from torchao.prototype.quantized_training.bitnet import (
             get_bitnet_scale,
             precompute_bitnet_scale_for_fsdp,
         )
 
-        model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32)).cuda()
+        model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32)).to(
+            _DEVICE
+        )
         model_fsdp = copy.deepcopy(model)
         quantize_(model_fsdp, bitnet_training())
         fully_shard(model_fsdp)
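For reference, a minimal, self-contained sketch of how module-level globals like these drive device-agnostic tests. The test names and shapes below are hypothetical, the sketch uses pytest's own parametrize rather than torch's internal helper, and _DEVICE is built directly from torch.accelerator instead of torchao's get_current_accelerator_device:

import pytest
import torch

# Standalone stand-ins for the globals introduced in the diff (illustrative only).
_DEVICES = (
    ["cpu"]
    + (["cuda"] if torch.cuda.is_available() else [])
    + (["xpu"] if torch.xpu.is_available() else [])
)
_DEVICE = (
    torch.accelerator.current_accelerator()
    if torch.accelerator.is_available()
    else torch.device("cpu")
)


@pytest.mark.parametrize("device", _DEVICES)
def test_runs_on_every_available_device(device):
    # Runs once per entry in _DEVICES: always "cpu", plus "cuda"/"xpu" when present.
    x = torch.randn(8, 8, device=device)
    assert x.device.type == torch.device(device).type


@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
def test_needs_an_accelerator():
    # Skipped entirely on machines without any accelerator backend.
    x = torch.randn(8, 8, device=_DEVICE)
    assert x.device.type != "cpu"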
