From 3cefa116699cb080caa1729f57eeb62cf79e8510 Mon Sep 17 00:00:00 2001
From: youn17
Date: Sun, 30 Nov 2025 15:54:32 +0900
Subject: [PATCH 1/4] build PERF for HQQ vs. Affine

---
 benchmarks/benchmark_intx.py               | 218 +++++++++++++++++++++
 test/quantization/test_quant_primitives.py |  34 ++++
 2 files changed, 252 insertions(+)
 create mode 100644 benchmarks/benchmark_intx.py

diff --git a/benchmarks/benchmark_intx.py b/benchmarks/benchmark_intx.py
new file mode 100644
index 0000000000..4fdf80e351
--- /dev/null
+++ b/benchmarks/benchmark_intx.py
@@ -0,0 +1,218 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+"""Benchmark calibration-free quantization methods: INT8-INT4, INT8-INT4-HQQ
+
+HQQ does not require a calibration flow and can be applied directly, without element-wise operations.
+
+Usage:
+    python benchmarks/benchmark_intx.py --model_id meta-llama/Llama-3.1-8B --limit 100
+"""
+
+import argparse
+import csv
+import gc
+import time
+from dataclasses import dataclass
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from torchao._models._eval import TransformerEvalWrapper
+from torchao.quantization import (
+    Int8DynamicActivationIntxWeightConfig,
+    IntxWeightOnlyConfig,
+    ModuleFqnToConfig,
+    quantize_,
+)
+from torchao.quantization.granularity import PerAxis, PerGroup
+from torchao.utils import benchmark_model, get_model_size_in_bytes
+
+
+@dataclass
+class Result:
+    method: str
+    size_gb: float
+    comp_ratio: float
+    quant_time_s: float
+    fwd_ms: float
+    tok_per_s: float
+    peak_mem_gb: float
+    accuracy: dict = None
+
+
+def get_config(method: str):
+    """Get quantization config for method."""
+    configs = {
+        "INT8-INT4": (
+            IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerAxis(0)),
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int4, weight_granularity=PerGroup(32)
+            ),
+        ),
+        "INT8-INT4-HQQ": (
+            IntxWeightOnlyConfig(
+                weight_dtype=torch.int8,
+                granularity=PerAxis(0),
+                intx_choose_qparams_algorithm="hqq_scale_only",
+            ),
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int4,
+                weight_granularity=PerGroup(32),
+                intx_choose_qparams_algorithm="hqq_scale_only",
+            ),
+        ),
+    }
+    emb_cfg, lin_cfg = configs[method]
+    return ModuleFqnToConfig({"_default": lin_cfg, "model.embed_tokens": emb_cfg})
+
+
+def benchmark_method(
+    model_id, method, baseline_size=None, tasks=None, limit=None, device="cuda"
+):
+    """Benchmark a single method."""
+    print(f"\n{'=' * 60}\n{method}\n{'=' * 60}")
+
+    # Ensure clean CUDA state before loading
+    torch.cuda.synchronize()
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    # Load and quantize
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id, device_map="auto", torch_dtype=torch.bfloat16
+    )
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+    torch.cuda.reset_peak_memory_stats()
+    t0 = time.time()
+    quantize_(model, get_config(method), filter_fn=None)
+    quant_time = time.time() - t0
+
+    # Metrics
+    size = get_model_size_in_bytes(model)
+    size_gb = size / 1e9
+    comp_ratio = baseline_size / size if baseline_size else 1.0
+
+    # Benchmark forward pass
+    inputs = torch.randint(0, 32000, (1, 512), device=device)
+    for _ in range(3):  # warmup
+        model(inputs)
+    fwd_ms = benchmark_model(model, 10, (inputs,), device_type=device)
+
+    # Benchmark generation
+    torch.cuda.synchronize()
+    t0 = time.perf_counter()
+    with torch.no_grad():
+        out = model.generate(inputs, max_new_tokens=100, do_sample=False)
+    torch.cuda.synchronize()
+    gen_time = time.perf_counter() - t0
+    tok_per_s = (out.shape[1] - inputs.shape[1]) / gen_time
+
+    peak_mem = torch.cuda.max_memory_allocated() / 1e9
+
+    # Eval
+    accuracy = None
+    eval_results = TransformerEvalWrapper(
+        model, tokenizer, 2048, device=device
+    ).run_eval(tasks, limit)
+    results_dict = eval_results.get("results", {})
+    first_task = list(results_dict.keys())[0]
+    accuracy = results_dict[first_task].get("exact_match,flexible-extract", None)
+
+    result = Result(
+        method, size_gb, comp_ratio, quant_time, fwd_ms, tok_per_s, peak_mem, accuracy
+    )
+
+    # Cleanup with proper synchronization
+    del model
+    del tokenizer
+    torch.cuda.synchronize()
+    torch.cuda.empty_cache()
+
+    print(
+        f"Size: {size_gb:.3f} GB ({comp_ratio:.2f}x) | Quant: {quant_time:.1f}s | "
+        f"Fwd: {fwd_ms:.2f}ms | Throughput: {tok_per_s:.1f} tok/s | Mem: {peak_mem:.2f} GB"
+    )
+
+    return result
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_id", required=True, help="HuggingFace model ID")
+    parser.add_argument("--methods", nargs="+", default=["INT8-INT4", "INT8-INT4-HQQ"])
+    parser.add_argument("--tasks", nargs="+", default=["gsm8k"], help="lm_eval tasks")
+    parser.add_argument("--limit", type=int, default=50, help="lm_eval limit per task")
+    parser.add_argument("--output", default="results.csv")
+    parser.add_argument("--device", default="cuda")
+    args = parser.parse_args()
+    print(f"Benchmarking {args.model_id} on {args.methods}")
+
+    # Baseline
+    print("\nMeasuring baseline...")
+    model = AutoModelForCausalLM.from_pretrained(
+        args.model_id, device_map="auto", torch_dtype=torch.bfloat16
+    )
+    baseline = get_model_size_in_bytes(model)
+    print(f"Baseline: {baseline / 1e9:.3f} GB")
+    del model
+    torch.cuda.empty_cache()
+
+    # Benchmark
+    results = []
+    for method in args.methods:
+        results.append(
+            benchmark_method(
+                args.model_id, method, baseline, args.tasks, args.limit, args.device
+            )
+        )
+        torch.cuda.empty_cache()
+
+    # Summary
+    print(f"\n{'=' * 80}\nSUMMARY\n{'=' * 80}")
+    print(
+        f"{'Method':<18} {'Size(GB)':<10} {'Comp':<8} {'Quant(s)':<10} {'Fwd(ms)':<10} {'Tok/s':<10} {'Mem(GB)':<10}"
+    )
+    print("-" * 80)
+    for r in results:
+        print(
+            f"{r.method:<18} {r.size_gb:<10.3f} {r.comp_ratio:<8.2f} {r.quant_time_s:<10.1f} "
+            f"{r.fwd_ms:<10.2f} {r.tok_per_s:<10.1f} {r.peak_mem_gb:<10.2f} {r.accuracy}"
+        )
+
+    # Save CSV
+    with open(args.output, "w") as f:
+        writer = csv.writer(f)
+        writer.writerow(
+            [
+                "method",
+                "size_gb",
+                "compression",
+                "quant_time_s",
+                "fwd_ms",
+                "tok_per_s",
+                "peak_mem_gb",
+                "accuracy",
+            ]
+        )
+        for r in results:
+            writer.writerow(
+                [
+                    r.method,
+                    f"{r.size_gb:.2f}",
+                    f"{r.comp_ratio:.2f}",
+                    f"{r.quant_time_s:.1f}",
+                    f"{r.fwd_ms:.2f}",
+                    f"{r.tok_per_s:.1f}",
+                    f"{r.peak_mem_gb:.2f}",
+                    f"{r.accuracy:.4f}",
+                ]
+            )
+    print(f"\nSaved to {args.output}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index 4bc3236759..738fbe4e40 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -15,6 +15,7 @@
     MappingType,
     ZeroPointDomain,
     _choose_qparams_affine_tinygemm,
+    _choose_qparams_and_quantize_scale_only_hqq,
     _choose_qparams_and_quantize_scale_only_sinq,
     _choose_scale_float8,
     _fake_quantize_affine,
@@ -863,6 +864,39 @@ def test_choose_qparams_and_quantize_scale_only_sinq(self):
         ).reshape(input.shape)
         self.assertFalse(torch.isnan(reconstructed).any())
 
+    def test_choose_qparams_and_quantize_scale_only_hqq(self):
+        """Test HQQ quantization produces valid outputs with correct shapes and ranges."""
+        torch.manual_seed(self.SEED)
+        input = torch.randn(128, 256, dtype=torch.float32)
+        block_size = [1, 64]
+        qmin = -(2 ** (4 - 1))
+        qmax = 2 ** (4 - 1) - 1
+
+        qdata, scale = _choose_qparams_and_quantize_scale_only_hqq(
+            input,
+            block_size=block_size,
+            qmin=qmin,
+            qmax=qmax,
+            iters=20,
+        )
+
+        # Check quantized data shape and dtype
+        self.assertEqual(qdata.dtype, torch.int32)
+        self.assertEqual(qdata.shape, input.shape)
+        self.assertTrue((qdata >= qmin).all() and (qdata <= qmax).all())
+
+        # Check scale shape and values
+        num_groups = input.shape[1] // block_size[1]
+        self.assertEqual(scale.shape, (input.shape[0], num_groups))
+        self.assertEqual(scale.dtype, input.dtype)
+        self.assertTrue((scale > 0).all())
+
+        # Test reconstruction is possible
+        scale_expanded = scale.repeat_interleave(block_size[1], dim=1)
+        reconstructed = qdata.to(input.dtype) * scale_expanded
+        self.assertFalse(torch.isnan(reconstructed).any())
+        self.assertEqual(reconstructed.shape, input.shape)
+
     def test_float8_blockwise_scaling(self):
         M, K = 512, 1024
         hp_tensor = torch.randn(M, K, dtype=torch.float)

From 5879030102a5bd1bc174f0ebf958d2fcb3db4bd4 Mon Sep 17 00:00:00 2001
From: youn17
Date: Sun, 30 Nov 2025 17:35:07 +0900
Subject: [PATCH 2/4] minor fix

---
 benchmarks/benchmark_intx.py | 18 ++++++++++++------
 torchao/_models/_eval.py     |  2 +-
 2 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/benchmarks/benchmark_intx.py b/benchmarks/benchmark_intx.py
index 4fdf80e351..7b48c1cd22 100644
--- a/benchmarks/benchmark_intx.py
+++ b/benchmarks/benchmark_intx.py
@@ -66,7 +66,12 @@ def get_config(method: str):
         ),
     }
     emb_cfg, lin_cfg = configs[method]
-    return ModuleFqnToConfig({"_default": lin_cfg, "model.embed_tokens": emb_cfg})
+    return ModuleFqnToConfig(
+        {
+            "re:.*\.(q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj)$": lin_cfg,
+            "model.embed_tokens": emb_cfg,
+        }
+    )
 
 
 def benchmark_method(
@@ -75,16 +80,17 @@ def benchmark_method(
     """Benchmark a single method."""
     print(f"\n{'=' * 60}\n{method}\n{'=' * 60}")
 
-    # Ensure clean CUDA state before loading
+    # Clean CUDA
     torch.cuda.synchronize()
     gc.collect()
     torch.cuda.empty_cache()
 
     # Load and quantize
     model = AutoModelForCausalLM.from_pretrained(
-        model_id, device_map="auto", torch_dtype=torch.bfloat16
+        model_id, device_map="auto", dtype=torch.bfloat16
     )
     tokenizer = AutoTokenizer.from_pretrained(model_id)
+    tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
 
     torch.cuda.reset_peak_memory_stats()
     t0 = time.time()
@@ -126,7 +132,7 @@ def benchmark_method(
         method, size_gb, comp_ratio, quant_time, fwd_ms, tok_per_s, peak_mem, accuracy
     )
 
-    # Cleanup with proper synchronization
+    # Clean CUDA
     del model
     del tokenizer
     torch.cuda.synchronize()
     torch.cuda.empty_cache()
@@ -146,7 +152,7 @@ def main():
     parser.add_argument("--methods", nargs="+", default=["INT8-INT4", "INT8-INT4-HQQ"])
     parser.add_argument("--tasks", nargs="+", default=["gsm8k"], help="lm_eval tasks")
     parser.add_argument("--limit", type=int, default=50, help="lm_eval limit per task")
-    parser.add_argument("--output", default="results.csv")
+    parser.add_argument("--output", default="benchmarks/benchmark_intx.csv")
     parser.add_argument("--device", default="cuda")
     args = parser.parse_args()
     print(f"Benchmarking {args.model_id} on {args.methods}")
@@ -154,7 +160,7 @@ def main():
     # Baseline
     print("\nMeasuring baseline...")
     model = AutoModelForCausalLM.from_pretrained(
-        args.model_id, device_map="auto", torch_dtype=torch.bfloat16
+        args.model_id, device_map="auto", dtype=torch.bfloat16
     )
     baseline = get_model_size_in_bytes(model)
     print(f"Baseline: {baseline / 1e9:.3f} GB")
     del model
     torch.cuda.empty_cache()
diff --git a/torchao/_models/_eval.py b/torchao/_models/_eval.py
index 29582e9b3d..4a0eff4f1b 100644
--- a/torchao/_models/_eval.py
+++ b/torchao/_models/_eval.py
@@ -106,7 +106,7 @@ def max_length(self):
 
     @property
     def max_gen_toks(self):
-        return 50
+        return 512
 
     @property
     def batch_size(self):

From 8ed1ac0f47cc67fc677adec5f1b795a929d8987e Mon Sep 17 00:00:00 2001
From: youn17
Date: Wed, 10 Dec 2025 15:45:06 +0900
Subject: [PATCH 3/4] drop (revert) intx benchmark

---
 benchmarks/benchmark_intx.py | 224 -----------------------------------
 torchao/_models/_eval.py     |   2 +-
 2 files changed, 1 insertion(+), 225 deletions(-)
 delete mode 100644 benchmarks/benchmark_intx.py

diff --git a/benchmarks/benchmark_intx.py b/benchmarks/benchmark_intx.py
deleted file mode 100644
index 7b48c1cd22..0000000000
--- a/benchmarks/benchmark_intx.py
+++ /dev/null
@@ -1,224 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD 3-Clause license found in the
-# LICENSE file in the root directory of this source tree.
-"""Benchmark calibration-free quantization methods: INT8-INT4, INT8-INT4-HQQ
-
-HQQ does not require a calibration flow and can be applied directly, without element-wise operations.
-
-Usage:
-    python benchmarks/benchmark_intx.py --model_id meta-llama/Llama-3.1-8B --limit 100
-"""
-
-import argparse
-import csv
-import gc
-import time
-from dataclasses import dataclass
-
-import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
-from torchao._models._eval import TransformerEvalWrapper
-from torchao.quantization import (
-    Int8DynamicActivationIntxWeightConfig,
-    IntxWeightOnlyConfig,
-    ModuleFqnToConfig,
-    quantize_,
-)
-from torchao.quantization.granularity import PerAxis, PerGroup
-from torchao.utils import benchmark_model, get_model_size_in_bytes
-
-
-@dataclass
-class Result:
-    method: str
-    size_gb: float
-    comp_ratio: float
-    quant_time_s: float
-    fwd_ms: float
-    tok_per_s: float
-    peak_mem_gb: float
-    accuracy: dict = None
-
-
-def get_config(method: str):
-    """Get quantization config for method."""
-    configs = {
-        "INT8-INT4": (
-            IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerAxis(0)),
-            Int8DynamicActivationIntxWeightConfig(
-                weight_dtype=torch.int4, weight_granularity=PerGroup(32)
-            ),
-        ),
-        "INT8-INT4-HQQ": (
-            IntxWeightOnlyConfig(
-                weight_dtype=torch.int8,
-                granularity=PerAxis(0),
-                intx_choose_qparams_algorithm="hqq_scale_only",
-            ),
-            Int8DynamicActivationIntxWeightConfig(
-                weight_dtype=torch.int4,
-                weight_granularity=PerGroup(32),
-                intx_choose_qparams_algorithm="hqq_scale_only",
-            ),
-        ),
-    }
-    emb_cfg, lin_cfg = configs[method]
-    return ModuleFqnToConfig(
-        {
-            "re:.*\.(q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj)$": lin_cfg,
-            "model.embed_tokens": emb_cfg,
-        }
-    )
-
-
-def benchmark_method(
-    model_id, method, baseline_size=None, tasks=None, limit=None, device="cuda"
-):
-    """Benchmark a single method."""
-    print(f"\n{'=' * 60}\n{method}\n{'=' * 60}")
-
-    # Clean CUDA
-    torch.cuda.synchronize()
-    gc.collect()
-    torch.cuda.empty_cache()
-
-    # Load and quantize
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id, device_map="auto", dtype=torch.bfloat16
-    )
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-    tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
-
-    torch.cuda.reset_peak_memory_stats()
-    t0 = time.time()
-    quantize_(model, get_config(method), filter_fn=None)
-    quant_time = time.time() - t0
-
-    # Metrics
-    size = get_model_size_in_bytes(model)
-    size_gb = size / 1e9
-    comp_ratio = baseline_size / size if baseline_size else 1.0
-
-    # Benchmark forward pass
-    inputs = torch.randint(0, 32000, (1, 512), device=device)
-    for _ in range(3):  # warmup
-        model(inputs)
-    fwd_ms = benchmark_model(model, 10, (inputs,), device_type=device)
-
-    # Benchmark generation
-    torch.cuda.synchronize()
-    t0 = time.perf_counter()
-    with torch.no_grad():
-        out = model.generate(inputs, max_new_tokens=100, do_sample=False)
-    torch.cuda.synchronize()
-    gen_time = time.perf_counter() - t0
-    tok_per_s = (out.shape[1] - inputs.shape[1]) / gen_time
-
-    peak_mem = torch.cuda.max_memory_allocated() / 1e9
-
-    # Eval
-    accuracy = None
-    eval_results = TransformerEvalWrapper(
-        model, tokenizer, 2048, device=device
-    ).run_eval(tasks, limit)
-    results_dict = eval_results.get("results", {})
-    first_task = list(results_dict.keys())[0]
-    accuracy = results_dict[first_task].get("exact_match,flexible-extract", None)
-
-    result = Result(
-        method, size_gb, comp_ratio, quant_time, fwd_ms, tok_per_s, peak_mem, accuracy
-    )
-
-    # Clean CUDA
-    del model
-    del tokenizer
-    torch.cuda.synchronize()
-    torch.cuda.empty_cache()
-
-    print(
-        f"Size: {size_gb:.3f} GB ({comp_ratio:.2f}x) | Quant: {quant_time:.1f}s | "
-        f"Fwd: {fwd_ms:.2f}ms | Throughput: {tok_per_s:.1f} tok/s | Mem: {peak_mem:.2f} GB"
-    )
-
-    return result
-
-
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model_id", required=True, help="HuggingFace model ID")
-    parser.add_argument("--methods", nargs="+", default=["INT8-INT4", "INT8-INT4-HQQ"])
-    parser.add_argument("--tasks", nargs="+", default=["gsm8k"], help="lm_eval tasks")
-    parser.add_argument("--limit", type=int, default=50, help="lm_eval limit per task")
-    parser.add_argument("--output", default="benchmarks/benchmark_intx.csv")
-    parser.add_argument("--device", default="cuda")
-    args = parser.parse_args()
-    print(f"Benchmarking {args.model_id} on {args.methods}")
-
-    # Baseline
-    print("\nMeasuring baseline...")
-    model = AutoModelForCausalLM.from_pretrained(
-        args.model_id, device_map="auto", dtype=torch.bfloat16
-    )
-    baseline = get_model_size_in_bytes(model)
-    print(f"Baseline: {baseline / 1e9:.3f} GB")
-    del model
-    torch.cuda.empty_cache()
-
-    # Benchmark
-    results = []
-    for method in args.methods:
-        results.append(
-            benchmark_method(
-                args.model_id, method, baseline, args.tasks, args.limit, args.device
-            )
-        )
-        torch.cuda.empty_cache()
-
-    # Summary
-    print(f"\n{'=' * 80}\nSUMMARY\n{'=' * 80}")
-    print(
-        f"{'Method':<18} {'Size(GB)':<10} {'Comp':<8} {'Quant(s)':<10} {'Fwd(ms)':<10} {'Tok/s':<10} {'Mem(GB)':<10}"
-    )
-    print("-" * 80)
-    for r in results:
-        print(
-            f"{r.method:<18} {r.size_gb:<10.3f} {r.comp_ratio:<8.2f} {r.quant_time_s:<10.1f} "
-            f"{r.fwd_ms:<10.2f} {r.tok_per_s:<10.1f} {r.peak_mem_gb:<10.2f} {r.accuracy}"
-        )
-
-    # Save CSV
-    with open(args.output, "w") as f:
-        writer = csv.writer(f)
-        writer.writerow(
-            [
-                "method",
-                "size_gb",
-                "compression",
-                "quant_time_s",
-                "fwd_ms",
-                "tok_per_s",
-                "peak_mem_gb",
-                "accuracy",
-            ]
-        )
-        for r in results:
-            writer.writerow(
-                [
-                    r.method,
-                    f"{r.size_gb:.2f}",
-                    f"{r.comp_ratio:.2f}",
-                    f"{r.quant_time_s:.1f}",
-                    f"{r.fwd_ms:.2f}",
-                    f"{r.tok_per_s:.1f}",
-                    f"{r.peak_mem_gb:.2f}",
-                    f"{r.accuracy:.4f}",
-                ]
-            )
-    print(f"\nSaved to {args.output}")
-
-
-if __name__ == "__main__":
-    main()
diff --git a/torchao/_models/_eval.py b/torchao/_models/_eval.py
index 4a0eff4f1b..29582e9b3d 100644
--- a/torchao/_models/_eval.py
+++ b/torchao/_models/_eval.py
@@ -106,7 +106,7 @@ def max_length(self):
 
     @property
    def max_gen_toks(self):
-        return 512
+        return 50
 
     @property
     def batch_size(self):

From f7323e378ac78139099359c2e58ee08e41b2e378 Mon Sep 17 00:00:00 2001
From: youn17
Date: Fri, 12 Dec 2025 00:44:39 +0900
Subject: [PATCH 4/4] add sqnr for sanity check

---
 test/quantization/test_quant_primitives.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index 738fbe4e40..0c1050a92d 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -30,6 +30,7 @@
 # TODO: remove test for utils?
 from torchao.quantization.utils import (
     _quantize_activation_per_token_absmax,
+    compute_error,
     get_block_size,
     get_group_qparams_symmetric,
     groupwise_affine_dequantize_tensor_from_qparams,
@@ -897,6 +898,9 @@ def test_choose_qparams_and_quantize_scale_only_hqq(self):
         self.assertFalse(torch.isnan(reconstructed).any())
         self.assertEqual(reconstructed.shape, input.shape)
 
+        error = compute_error(input, reconstructed)
+        self.assertLess(error, 25)
+
     def test_float8_blockwise_scaling(self):
         M, K = 512, 1024
         hp_tensor = torch.randn(M, K, dtype=torch.float)