From 8342dbdcdc196640db1124dcba111619167abf28 Mon Sep 17 00:00:00 2001
From: m96-chan
Date: Tue, 30 Dec 2025 11:39:47 +0900
Subject: [PATCH] feat(llm): add QAT/Pruning/Sparsity model config support
 (#115)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add support for loading models optimized with QAT, pruning, and sparsity:

New config classes in loader.py:
- QATQuantConfig: Parse QAT/QAD configs from TensorRT Model Optimizer
  and HuggingFace formats (AWQ, GPTQ, BNB, etc.)
- PruningConfig: Detect structured pruning (pruned_heads) and
  unstructured pruning configs
- SparsityConfig: Support 2:4 structured sparsity patterns for
  Ampere+ TensorCores
- ModelOptimizationInfo: Aggregate all optimization info with
  has_any_optimization() and summary() helpers

Supported formats:
- TensorRT Model Optimizer hf_quant_config.json (NVFP4, FP8, INT8)
- HuggingFace quantization_config (AWQ, GPTQ, BNB)
- HuggingFace pruned_heads (structured pruning)
- 2:4 sparsity pattern for sparse TensorCore ops

Reference:
- https://nvidia.github.io/TensorRT-Model-Optimizer/
- https://github.com/huggingface/nn_pruning

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5
---
 src/pygpukit/llm/__init__.py |  14 ++-
 src/pygpukit/llm/loader.py   | 228 +++++++++++++++++++++++++++++++++++
 2 files changed, 241 insertions(+), 1 deletion(-)

diff --git a/src/pygpukit/llm/__init__.py b/src/pygpukit/llm/__init__.py
index 06b7fa3..88b0e51 100644
--- a/src/pygpukit/llm/__init__.py
+++ b/src/pygpukit/llm/__init__.py
@@ -585,7 +585,13 @@ def __repr__(self) -> str:
 )
 
 # Loaders (refactored v0.2.11)
-from pygpukit.llm.loader import (  # noqa: E402
+# Quantization/Optimization configs (v0.2.18 - Issue #115)
+from pygpukit.llm.loader import (  # noqa: E402
+    FP8QuantConfig,
+    ModelOptimizationInfo,
+    PruningConfig,
+    QATQuantConfig,
+    SparsityConfig,
     load_gpt2_from_safetensors,
     load_llama_from_safetensors,
     load_mixtral_from_safetensors,
@@ -685,4 +691,10 @@ def __repr__(self) -> str:
     "DecodeBatch",
     "DecodeSpeculative",
     "DecodeJacobi",
+    # Quantization/Optimization configs (v0.2.18 - Issue #115)
+    "FP8QuantConfig",
+    "QATQuantConfig",
+    "PruningConfig",
+    "SparsityConfig",
+    "ModelOptimizationInfo",
 ]
diff --git a/src/pygpukit/llm/loader.py b/src/pygpukit/llm/loader.py
index b71d856..eb47246 100644
--- a/src/pygpukit/llm/loader.py
+++ b/src/pygpukit/llm/loader.py
@@ -75,6 +75,234 @@ def from_config(cls, config: dict) -> FP8QuantConfig | None:
         )
 
 
+# =============================================================================
+# QAT/QAD Quantization Support (Issue #115)
+# =============================================================================
+
+
+@dataclass
+class QATQuantConfig:
+    """QAT (Quantization-Aware Training) configuration.
+
+    Supports models trained with:
+    - NVIDIA TensorRT Model Optimizer
+    - HuggingFace Optimum
+    - PyTorch Quantization
+
+    Reference:
+    - https://nvidia.github.io/TensorRT-Model-Optimizer/
+    - https://developer.nvidia.com/blog/top-5-ai-model-optimization-techniques-for-faster-smarter-inference/
+    """
+
+    quant_method: str  # "qat", "modelopt", "nvfp4", etc.
+    quant_algo: str  # "FP8", "INT8", "NVFP4", "W8A8", etc.
+    group_size: int  # Block/group size for quantization
+    kv_cache_quant_algo: str | None  # KV cache quantization (optional)
+    exclude_modules: list[str]  # Modules to skip quantization
+    producer: str | None  # Tool that produced the checkpoint (e.g., "modelopt")
+    producer_version: str | None  # Version of the producer tool
+
+    @classmethod
+    def from_config(cls, config: dict) -> QATQuantConfig | None:
+        """Parse QAT config from HF config.json or hf_quant_config.json."""
+        # Check for TensorRT Model Optimizer format (hf_quant_config.json style)
+        if "producer" in config and "quantization" in config:
+            producer_info = config.get("producer", {})
+            quant_info = config.get("quantization", {})
+            return cls(
+                quant_method="modelopt",
+                quant_algo=quant_info.get("quant_algo", "unknown"),
+                group_size=quant_info.get("group_size", 128),
+                kv_cache_quant_algo=quant_info.get("kv_cache_quant_algo"),
+                exclude_modules=quant_info.get("exclude_modules", []),
+                producer=producer_info.get("name"),
+                producer_version=producer_info.get("version"),
+            )
+
+        # Check for HF quantization_config with a quantized-checkpoint method
+        qc = config.get("quantization_config")
+        if qc is None:
+            return None
+
+        quant_method = qc.get("quant_method", "")
+        # Methods handled here: "qat", "awq", "gptq", etc. ("fp8" is handled by FP8QuantConfig)
+        qat_methods = {"qat", "awq", "gptq", "bnb", "bitsandbytes", "modelopt"}
+        if quant_method not in qat_methods:
+            return None
+
+        return cls(
+            quant_method=quant_method,
+            quant_algo=str(qc.get("quant_algo", qc.get("bits", "unknown"))),
+            group_size=qc.get("group_size", qc.get("block_size", 128)),
+            kv_cache_quant_algo=qc.get("kv_cache_quant_algo"),
+            exclude_modules=qc.get("modules_to_not_convert", []),
+            producer=None,
+            producer_version=None,
+        )
+
+
+# =============================================================================
+# Pruning Support (Issue #115)
+# =============================================================================
+
+
+@dataclass
+class PruningConfig:
+    """Pruning configuration for pruned model checkpoints.
+ + Supports models pruned with: + - NVIDIA TensorRT Model Optimizer + - HuggingFace nn_pruning + - Neural Compressor + + Reference: + - https://github.com/huggingface/nn_pruning + - https://github.com/NVIDIA/TensorRT-Model-Optimizer + """ + + pruning_method: str # "magnitude", "movement", "structured", "unstructured" + sparsity: float # Target sparsity (0.0 to 1.0) + pruned_heads: dict[int, list[int]] | None # Layer -> pruned head indices + is_structured: bool # True if structured pruning (removes entire heads/neurons) + + @classmethod + def from_config(cls, config: dict) -> PruningConfig | None: + """Parse pruning config from HF config.json.""" + # Check for pruned_heads (HuggingFace standard) + pruned_heads = config.get("pruned_heads") + if pruned_heads: + # Convert string keys to int if needed + if isinstance(pruned_heads, dict): + pruned_heads = {int(k): v for k, v in pruned_heads.items()} + return cls( + pruning_method="structured", + sparsity=0.0, # Unknown from config alone + pruned_heads=pruned_heads, + is_structured=True, + ) + + # Check for pruning_config section + pc = config.get("pruning_config") + if pc is None: + return None + + return cls( + pruning_method=pc.get("pruning_type", pc.get("method", "unknown")), + sparsity=pc.get("target_sparsity", pc.get("sparsity", 0.0)), + pruned_heads=pc.get("pruned_heads"), + is_structured=pc.get("is_structured", pc.get("structured", False)), + ) + + +# ============================================================================= +# Sparsity Pattern Support (Issue #115) +# ============================================================================= + + +@dataclass +class SparsityConfig: + """Sparsity pattern configuration for sparse tensor operations. + + Supports: + - 2:4 structured sparsity (Ampere+) + - Block sparsity patterns + - Custom sparsity masks + + Reference: + - https://developer.nvidia.com/blog/accelerating-inference-with-sparsity-using-ampere-and-tensorrt/ + """ + + pattern: str # "2:4", "4:8", "block", "unstructured" + block_size: tuple[int, int] | None # For block sparsity + density: float # Non-zero ratio (1 - sparsity) + + @classmethod + def from_config(cls, config: dict) -> SparsityConfig | None: + """Parse sparsity config from HF config.json.""" + sc = config.get("sparsity_config") + if sc is None: + # Check for sparsity in quantization_config + qc = config.get("quantization_config", {}) + sparsity_pattern = qc.get("sparsity_pattern") + if sparsity_pattern: + return cls( + pattern=sparsity_pattern, + block_size=None, + density=1.0 - qc.get("sparsity", 0.5), + ) + return None + + pattern = sc.get("pattern", sc.get("sparsity_pattern", "unknown")) + block_size = sc.get("block_size") + if block_size and isinstance(block_size, list): + block_size = tuple(block_size) + + return cls( + pattern=pattern, + block_size=block_size, + density=sc.get("density", 1.0 - sc.get("sparsity", 0.0)), + ) + + def is_2_4_sparse(self) -> bool: + """Check if this is 2:4 structured sparsity (Ampere+ TensorCore).""" + return self.pattern == "2:4" + + +# ============================================================================= +# Model Optimization Info (Issue #115) +# ============================================================================= + + +@dataclass +class ModelOptimizationInfo: + """Combined optimization information for a model. + + Aggregates all optimization techniques applied to the model: + - Quantization (FP8, QAT, etc.) 
+ - Pruning (structured, unstructured) + - Sparsity (2:4, block) + """ + + fp8_config: FP8QuantConfig | None + qat_config: QATQuantConfig | None + pruning_config: PruningConfig | None + sparsity_config: SparsityConfig | None + + @classmethod + def from_config(cls, config: dict) -> ModelOptimizationInfo: + """Parse all optimization configs from config.json.""" + return cls( + fp8_config=FP8QuantConfig.from_config(config), + qat_config=QATQuantConfig.from_config(config), + pruning_config=PruningConfig.from_config(config), + sparsity_config=SparsityConfig.from_config(config), + ) + + def has_any_optimization(self) -> bool: + """Check if any optimization is applied.""" + return any( + [ + self.fp8_config, + self.qat_config, + self.pruning_config, + self.sparsity_config, + ] + ) + + def summary(self) -> str: + """Return a summary string of optimizations.""" + parts = [] + if self.fp8_config: + parts.append(f"FP8({self.fp8_config.fmt})") + if self.qat_config: + parts.append(f"QAT({self.qat_config.quant_algo})") + if self.pruning_config: + parts.append(f"Pruned({self.pruning_config.pruning_method})") + if self.sparsity_config: + parts.append(f"Sparse({self.sparsity_config.pattern})") + return ", ".join(parts) if parts else "None" + + # FP8 E4M3 to float32 lookup table (256 entries) # Format: 1 sign bit, 4 exponent bits, 3 mantissa bits # Special values: NaN (0x7F/0xFF), no infinity
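
Reviewer note (not part of the patch): below is a minimal usage sketch of the
new classes, assuming the `pygpukit.llm` exports added above. The config dicts
are illustrative stand-ins for a parsed config.json / hf_quant_config.json
(the producer version and module names are made up), and the expected summary
output assumes FP8QuantConfig returns None for non-FP8 quant methods.

```python
from pygpukit.llm import ModelOptimizationInfo, QATQuantConfig

# TensorRT Model Optimizer style (hf_quant_config.json): the presence of
# "producer" + "quantization" keys selects the modelopt branch.
modelopt_cfg = {
    "producer": {"name": "modelopt", "version": "0.19.0"},  # illustrative
    "quantization": {
        "quant_algo": "NVFP4",
        "group_size": 16,
        "kv_cache_quant_algo": "FP8",
        "exclude_modules": ["lm_head"],
    },
}
qat = QATQuantConfig.from_config(modelopt_cfg)
assert qat is not None
print(qat.quant_algo, qat.producer)  # NVFP4 modelopt

# HuggingFace config.json style: AWQ quantization plus structured pruning
# (pruned_heads) and a 2:4 sparsity pattern, all parsed in one pass.
hf_cfg = {
    "quantization_config": {"quant_method": "awq", "bits": 4, "group_size": 128},
    "pruned_heads": {"0": [1, 3], "5": [0]},  # JSON keys arrive as strings
    "sparsity_config": {"pattern": "2:4", "sparsity": 0.5},
}
info = ModelOptimizationInfo.from_config(hf_cfg)
print(info.has_any_optimization())           # True
print(info.sparsity_config.is_2_4_sparse())  # True
print(info.summary())  # e.g. "QAT(4), Pruned(structured), Sparse(2:4)"
```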