Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/pygpukit/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,13 @@ def __repr__(self) -> str:
)

# Loaders (refactored v0.2.11)
from pygpukit.llm.loader import ( # noqa: E402
# Quantization/Optimization configs (v0.2.18 - Issue #115)
from pygpukit.llm.loader import ( # noqa: E402 # noqa: E402
FP8QuantConfig,
ModelOptimizationInfo,
PruningConfig,
QATQuantConfig,
SparsityConfig,
load_gpt2_from_safetensors,
load_llama_from_safetensors,
load_mixtral_from_safetensors,
Expand Down Expand Up @@ -685,4 +691,10 @@ def __repr__(self) -> str:
"DecodeBatch",
"DecodeSpeculative",
"DecodeJacobi",
# Quantization/Optimization configs (v0.2.18 - Issue #115)
"FP8QuantConfig",
"QATQuantConfig",
"PruningConfig",
"SparsityConfig",
"ModelOptimizationInfo",
]
228 changes: 228 additions & 0 deletions src/pygpukit/llm/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,234 @@ def from_config(cls, config: dict) -> FP8QuantConfig | None:
)


# =============================================================================
# QAT/QAD Quantization Support (Issue #115)
# =============================================================================


@dataclass
class QATQuantConfig:
"""QAT (Quantization-Aware Training) configuration.

Supports models trained with:
- NVIDIA TensorRT Model Optimizer
- HuggingFace Optimum
- PyTorch Quantization

Reference:
- https://nvidia.github.io/TensorRT-Model-Optimizer/
- https://developer.nvidia.com/blog/top-5-ai-model-optimization-techniques-for-faster-smarter-inference/
"""

quant_method: str # "qat", "modelopt", "nvfp4", etc.
quant_algo: str # "FP8", "INT8", "NVFP4", "W8A8", etc.
group_size: int # Block/group size for quantization
kv_cache_quant_algo: str | None # KV cache quantization (optional)
exclude_modules: list[str] # Modules to skip quantization
producer: str | None # Tool that produced the checkpoint (e.g., "modelopt")
producer_version: str | None # Version of the producer tool

@classmethod
def from_config(cls, config: dict) -> QATQuantConfig | None:
"""Parse QAT config from HF config.json or hf_quant_config.json."""
# Check for TensorRT Model Optimizer format (hf_quant_config.json style)
if "producer" in config and "quantization" in config:
producer_info = config.get("producer", {})
quant_info = config.get("quantization", {})
return cls(
quant_method="modelopt",
quant_algo=quant_info.get("quant_algo", "unknown"),
group_size=quant_info.get("group_size", 128),
kv_cache_quant_algo=quant_info.get("kv_cache_quant_algo"),
exclude_modules=quant_info.get("exclude_modules", []),
producer=producer_info.get("name"),
producer_version=producer_info.get("version"),
)

# Check for HF quantization_config with QAT method
qc = config.get("quantization_config")
if qc is None:
return None

quant_method = qc.get("quant_method", "")
# QAT methods: "qat", "awq", "gptq", etc. (exclude "fp8" which is handled separately)
qat_methods = {"qat", "awq", "gptq", "bnb", "modelopt"}
if quant_method not in qat_methods:
return None

return cls(
quant_method=quant_method,
quant_algo=qc.get("quant_algo", qc.get("bits", "unknown")),
group_size=qc.get("group_size", qc.get("block_size", 128)),
kv_cache_quant_algo=qc.get("kv_cache_quant_algo"),
exclude_modules=qc.get("modules_to_not_convert", []),
producer=None,
producer_version=None,
)


# =============================================================================
# Pruning Support (Issue #115)
# =============================================================================


@dataclass
class PruningConfig:
"""Pruning configuration for structurally smaller models.

Supports models pruned with:
- NVIDIA TensorRT Model Optimizer
- HuggingFace nn_pruning
- Neural Compressor

Reference:
- https://github.com/huggingface/nn_pruning
- https://github.com/NVIDIA/TensorRT-Model-Optimizer
"""

pruning_method: str # "magnitude", "movement", "structured", "unstructured"
sparsity: float # Target sparsity (0.0 to 1.0)
pruned_heads: dict[int, list[int]] | None # Layer -> pruned head indices
is_structured: bool # True if structured pruning (removes entire heads/neurons)

@classmethod
def from_config(cls, config: dict) -> PruningConfig | None:
"""Parse pruning config from HF config.json."""
# Check for pruned_heads (HuggingFace standard)
pruned_heads = config.get("pruned_heads")
if pruned_heads:
# Convert string keys to int if needed
if isinstance(pruned_heads, dict):
pruned_heads = {int(k): v for k, v in pruned_heads.items()}
return cls(
pruning_method="structured",
sparsity=0.0, # Unknown from config alone
pruned_heads=pruned_heads,
is_structured=True,
)

# Check for pruning_config section
pc = config.get("pruning_config")
if pc is None:
return None

return cls(
pruning_method=pc.get("pruning_type", pc.get("method", "unknown")),
sparsity=pc.get("target_sparsity", pc.get("sparsity", 0.0)),
pruned_heads=pc.get("pruned_heads"),
is_structured=pc.get("is_structured", pc.get("structured", False)),
)


# =============================================================================
# Sparsity Pattern Support (Issue #115)
# =============================================================================


@dataclass
class SparsityConfig:
"""Sparsity pattern configuration for sparse tensor operations.

Supports:
- 2:4 structured sparsity (Ampere+)
- Block sparsity patterns
- Custom sparsity masks

Reference:
- https://developer.nvidia.com/blog/accelerating-inference-with-sparsity-using-ampere-and-tensorrt/
"""

pattern: str # "2:4", "4:8", "block", "unstructured"
block_size: tuple[int, int] | None # For block sparsity
density: float # Non-zero ratio (1 - sparsity)

@classmethod
def from_config(cls, config: dict) -> SparsityConfig | None:
"""Parse sparsity config from HF config.json."""
sc = config.get("sparsity_config")
if sc is None:
# Check for sparsity in quantization_config
qc = config.get("quantization_config", {})
sparsity_pattern = qc.get("sparsity_pattern")
if sparsity_pattern:
return cls(
pattern=sparsity_pattern,
block_size=None,
density=1.0 - qc.get("sparsity", 0.5),
)
return None

pattern = sc.get("pattern", sc.get("sparsity_pattern", "unknown"))
block_size = sc.get("block_size")
if block_size and isinstance(block_size, list):
block_size = tuple(block_size)

return cls(
pattern=pattern,
block_size=block_size,
density=sc.get("density", 1.0 - sc.get("sparsity", 0.0)),
)

def is_2_4_sparse(self) -> bool:
"""Check if this is 2:4 structured sparsity (Ampere+ TensorCore)."""
return self.pattern == "2:4"


# =============================================================================
# Model Optimization Info (Issue #115)
# =============================================================================


@dataclass
class ModelOptimizationInfo:
"""Combined optimization information for a model.

Aggregates all optimization techniques applied to the model:
- Quantization (FP8, QAT, etc.)
- Pruning (structured, unstructured)
- Sparsity (2:4, block)
"""

fp8_config: FP8QuantConfig | None
qat_config: QATQuantConfig | None
pruning_config: PruningConfig | None
sparsity_config: SparsityConfig | None

@classmethod
def from_config(cls, config: dict) -> ModelOptimizationInfo:
"""Parse all optimization configs from config.json."""
return cls(
fp8_config=FP8QuantConfig.from_config(config),
qat_config=QATQuantConfig.from_config(config),
pruning_config=PruningConfig.from_config(config),
sparsity_config=SparsityConfig.from_config(config),
)

def has_any_optimization(self) -> bool:
"""Check if any optimization is applied."""
return any(
[
self.fp8_config,
self.qat_config,
self.pruning_config,
self.sparsity_config,
]
)

def summary(self) -> str:
"""Return a summary string of optimizations."""
parts = []
if self.fp8_config:
parts.append(f"FP8({self.fp8_config.fmt})")
if self.qat_config:
parts.append(f"QAT({self.qat_config.quant_algo})")
if self.pruning_config:
parts.append(f"Pruned({self.pruning_config.pruning_method})")
if self.sparsity_config:
parts.append(f"Sparse({self.sparsity_config.pattern})")
return ", ".join(parts) if parts else "None"


# FP8 E4M3 to float32 lookup table (256 entries)
# Format: 1 sign bit, 4 exponent bits, 3 mantissa bits
# Special values: NaN (0x7F/0xFF), no infinity
Expand Down