diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index 4bc3236759..0c1050a92d 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -15,6 +15,7 @@
     MappingType,
     ZeroPointDomain,
     _choose_qparams_affine_tinygemm,
+    _choose_qparams_and_quantize_scale_only_hqq,
     _choose_qparams_and_quantize_scale_only_sinq,
     _choose_scale_float8,
     _fake_quantize_affine,
@@ -29,6 +30,7 @@
 # TODO: remove test for utils?
 from torchao.quantization.utils import (
     _quantize_activation_per_token_absmax,
+    compute_error,
     get_block_size,
     get_group_qparams_symmetric,
     groupwise_affine_dequantize_tensor_from_qparams,
@@ -863,6 +865,43 @@ def test_choose_qparams_and_quantize_scale_only_sinq(self):
         ).reshape(input.shape)
         self.assertFalse(torch.isnan(reconstructed).any())
 
+    def test_choose_qparams_and_quantize_scale_only_hqq(self):
+        """Test HQQ quantization produces valid outputs with correct shapes and ranges."""
+        torch.manual_seed(self.SEED)
+        input = torch.randn(128, 256, dtype=torch.float32)
+        block_size = [1, 64]
+        qmin = -(2 ** (4 - 1))
+        qmax = 2 ** (4 - 1) - 1
+
+        qdata, scale = _choose_qparams_and_quantize_scale_only_hqq(
+            input,
+            block_size=block_size,
+            qmin=qmin,
+            qmax=qmax,
+            iters=20,
+        )
+
+        # Check quantized data shape and dtype
+        self.assertEqual(qdata.dtype, torch.int32)
+        self.assertEqual(qdata.shape, input.shape)
+        self.assertTrue((qdata >= qmin).all() and (qdata <= qmax).all())
+
+        # Check scale shape and values
+        num_groups = input.shape[1] // block_size[1]
+        self.assertEqual(scale.shape, (input.shape[0], num_groups))
+        self.assertEqual(scale.dtype, input.dtype)
+        self.assertTrue((scale > 0).all())
+
+        # Test reconstruction is possible
+        scale_expanded = scale.repeat_interleave(block_size[1], dim=1)
+        reconstructed = qdata.to(input.dtype) * scale_expanded
+        self.assertFalse(torch.isnan(reconstructed).any())
+        self.assertEqual(reconstructed.shape, input.shape)
+
+        # compute_error returns SQNR in dB (higher is better); 15 dB is a loose bound for 4-bit groupwise
+        error = compute_error(input, reconstructed)
+        self.assertGreater(error, 15)
+
     def test_float8_blockwise_scaling(self):
         M, K = 512, 1024
         hp_tensor = torch.randn(M, K, dtype=torch.float)
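
For context, a minimal sketch of the scale-only HQQ procedure this test exercises. This is a hypothetical reimplementation: the helper name `hqq_scale_only_sketch`, the L1 soft-threshold shrinkage, and the fixed `beta` are assumptions on my part; torchao's `_choose_qparams_and_quantize_scale_only_hqq` may use a different proximal operator and parameter schedule.

import torch


def hqq_scale_only_sketch(w, group_size=64, qmin=-8, qmax=7, iters=20, beta=10.0):
    """Groupwise scale-only quantization refined by half-quadratic iterations."""
    rows, cols = w.shape
    wg = w.reshape(rows, cols // group_size, group_size)

    # Round-to-nearest baseline: one positive scale per group from the absmax.
    scale = wg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / qmax

    for _ in range(iters):
        q = torch.clamp(torch.round(wg / scale), qmin, qmax)
        err = wg - q * scale
        # Proximal step: soft-threshold the residual (an L1-style shrinkage
        # stands in here for HQQ's lp<1 operator), giving a denoised target.
        e = torch.sign(err) * torch.clamp(err.abs() - 1.0 / beta, min=0.0)
        target = wg - e
        # Refit the scale per group by least squares against the target.
        num = (q * target).sum(dim=-1, keepdim=True)
        den = (q * q).sum(dim=-1, keepdim=True).clamp(min=1e-8)
        scale = (num / den).abs().clamp(min=1e-8)

    qdata = torch.clamp(torch.round(wg / scale), qmin, qmax).to(torch.int32)
    return qdata.reshape(rows, cols), scale.squeeze(-1)


# Shapes mirror the test: (128, 256) input, groups of 64 along dim 1,
# yielding int32 qdata of the same shape and a (128, 4) scale tensor.
w = torch.randn(128, 256)
qdata, scale = hqq_scale_only_sketch(w)
assert qdata.shape == (128, 256) and scale.shape == (128, 4)

The alternation (round with the current scale, shrink the residual, refit the scale by least squares) is the half-quadratic splitting that gives HQQ its name; the test's `iters=20` corresponds to the loop count here.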