126 changes: 94 additions & 32 deletions QEfficient/base/modeling_qeff.py
@@ -78,28 +78,71 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
else:
logger.info(f"Pytorch transforms applied to model: {self.model_name}")

def _offload_model_weights(self, offload_pt_weights) -> bool:
"""
Clear PyTorch weights after export if offload_pt_weights is set to True

Returns:
bool: True if weights were successfully offloaded, False otherwise
"""
# Check if offloading is enabled and weights are not already offloaded
if offload_pt_weights and not self._is_weights_offloaded:
try:
self.model = self.model.to_empty(device="meta")
self._is_weights_offloaded = True
logger.info("Model weights offloaded to meta device")

gc.collect()
logger.info("PyTorch weights cleared after export")
return True
def _clear_model_weights(self) -> None:
"""Clear PyTorch model weights to reduce memory usage after ONNX export."""
try:
# Clear tensor storage and replace with empty shell
for param in self.model.parameters():
if hasattr(param, "data") and hasattr(param.data, "storage"):
param.data.storage().resize_(0)

for buffer in self.model.buffers():
if hasattr(buffer, "data") and hasattr(buffer.data, "storage"):
buffer.data.storage().resize_(0)

# Clear module dictionaries and hooks
for module in self.model.modules():
if hasattr(module, "_parameters"):
module._parameters.clear()
if hasattr(module, "_buffers"):
module._buffers.clear()

# Clear hooks
for hook_dict in [
getattr(module, "_forward_hooks", {}),
getattr(module, "_forward_pre_hooks", {}),
getattr(module, "_backward_hooks", {}),
getattr(module, "_state_dict_hooks", {}),
getattr(module, "_load_state_dict_pre_hooks", {}),
]:
hook_dict.clear()

# Replace with minimal shell for compatibility
class ModelShell:
def __init__(self, config):
self.config = config
self.qaic_config = None
self.device = torch.device("meta")

def parameters(self):
return iter([])

def named_parameters(self):
return iter([])

def buffers(self):
return iter([])

def named_buffers(self):
return iter([])

def modules(self):
return iter([self])

def state_dict(self):
return {}

def to(self, device):
return self

def eval(self):
return self

config = getattr(self.model, "config", None)
self.model = ModelShell(config)

except Exception as e:
logger.error(f"Failed to offload model weights: {e}")
return False
return False
except Exception as e:
logger.warning(f"Weight clearing failed, continuing: {e}")

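For context, a minimal standalone sketch of the storage-resize idea used in _clear_model_weights, assuming only PyTorch is installed: resizing a parameter's backing storage to zero releases the weight memory while the tensor and module objects remain alive. The layer and sizes are illustrative, not the QEfficient code path.

import torch

layer = torch.nn.Linear(1024, 1024)
print(layer.weight.data.storage().size())  # 1048576 elements backed by real memory

# Drop the backing storage; the parameter object survives as an empty shell.
layer.weight.data.storage().resize_(0)
print(layer.weight.data.storage().size())  # 0
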
def _model_offloaded_check(self) -> None:
"""
@@ -244,19 +287,32 @@ def _export(

try:
export_kwargs = {} if export_kwargs is None else export_kwargs
torch.onnx.export(
self.model,
(example_inputs,),
str(tmp_onnx_path),
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=constants.ONNX_EXPORT_OPSET,
**export_kwargs,
)

with torch.no_grad():
torch.onnx.export(
self.model,
(example_inputs,),
str(tmp_onnx_path),
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=constants.ONNX_EXPORT_OPSET,
**export_kwargs,
)
logger.info("PyTorch export successful")

_ = self._offload_model_weights(offload_pt_weights)
# Clear PyTorch weights after successful export to reduce memory usage
if offload_pt_weights:
self._clear_model_weights()
self._is_weights_offloaded = True
logger.info("PyTorch weights cleared after ONNX export")

# Clear temporary references
example_inputs.clear()
input_names.clear()

# Force garbage collection
gc.collect()

model = onnx.load(tmp_onnx_path, load_external_data=False)
transform_kwargs = {
@@ -283,6 +339,12 @@

finally:
shutil.rmtree(tmp_onnx_dir, ignore_errors=True)
# Clear external data from memory and cache after all transforms and saving
# Make sure model exists before trying to clean it up
if "model" in locals():
OnnxTransform._cleanup_external_data_and_cache(model)
OnnxTransform._cleanup_memory()
logger.info("Cleanup complete.")

self.onnx_path = onnx_path
return onnx_path
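
To make the intent of the new finally block concrete, here is a hedged standalone sketch of the cleanup-in-finally pattern: whether or not the ONNX transforms succeed, any raw tensor data pulled into memory is released and a garbage-collection pass runs before leaving the function. The function name and body are illustrative, not the actual _export implementation.

import gc

import onnx
from onnx import external_data_helper


def export_with_cleanup(tmp_onnx_path: str) -> onnx.ModelProto:
    model = None
    try:
        model = onnx.load(tmp_onnx_path, load_external_data=False)
        # ... apply ONNX transforms and save the final artifact here ...
        return model
    finally:
        if model is not None:
            # Release any raw tensor data the transforms loaded into memory.
            for tensor in external_data_helper._get_all_tensors(model):
                if tensor.HasField("raw_data"):
                    tensor.ClearField("raw_data")
        gc.collect()
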
158 changes: 129 additions & 29 deletions QEfficient/base/onnx_transforms.py
@@ -5,17 +5,25 @@
#
# ----------------------------------------------------------------------------

import gc
import logging
from typing import Optional, Tuple

import numpy as np
from onnx import ModelProto, external_data_helper, numpy_helper

from QEfficient.utils.constants import ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL

Contributor comment: NIT: TRANSFORM - spell check

logger = logging.getLogger(__name__)


class OnnxTransform:
"""
OnnxTransform is the base class for graph modifications on exported onnx.
"""

_external_data_loaded_cache = {} # Dict[int, bool]

def __init__(self):
raise TypeError("Transform classes are not to be instantiated. Directly use the `apply` method.")

@@ -31,6 +39,68 @@ def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]:
"""
raise NotImplementedError("Use subclasses for ONNX transform")

@classmethod
def _check_external_data_loaded(cls, model: ModelProto) -> bool:
"""
Check if external data is already loaded in the model.

:param model: The ONNX model to check
:returns: True if external data is already loaded, False otherwise
"""
# Use object ID as key instead of the object itself
model_id = id(model)
# Return cached result if available
if model_id in cls._external_data_loaded_cache:
return cls._external_data_loaded_cache[model_id]

# Load the model if not already loaded
for tensor in external_data_helper._get_all_tensors(model):
Contributor comment: Can we skip this extra loop that checks whether external data has been loaded for all tensors? At the place where the external data is loaded we could maintain a flag that defaults to False and is set to True once everything is loaded; the rest of the code would then only check that flag, and this function may not be needed at all if the flag is used directly. (A sketch of this flag-based approach follows the base class below.)

# Check if tensor has external data but no raw data loaded
if len(tensor.external_data) > 0 and not tensor.HasField("raw_data"):
cls._external_data_loaded_cache[model_id] = False
return False

cls._external_data_loaded_cache[model_id] = True
return True

@classmethod
def _load_external_data(cls, model: ModelProto, onnx_base_dir: Optional[str] = None):
"""
Performs a bulk load of external data if it's not already loaded.
Updates the cache upon successful load.
"""
model_id = id(model)
if not cls._check_external_data_loaded(model):
logger.info("External data not loaded. Performing bulk load.")
external_data_helper.load_external_data_for_model(model, onnx_base_dir)
cls._external_data_loaded_cache[model_id] = True
else:
logger.info("External data already loaded (or cached). Skipping bulk load.")

@classmethod
def _cleanup_external_data_and_cache(cls, model: ModelProto):
"""
Combines clearing external data from the model and its cache entry.
"""
# Remove the loaded raw data from tensors
for tensor in external_data_helper._get_all_tensors(model):
if tensor.HasField("raw_data"):
tensor.ClearField("raw_data")

# Clear the cache entry for this model using its ID
model_id = id(model)
if model_id in cls._external_data_loaded_cache:
del cls._external_data_loaded_cache[model_id]

logger.info("External data and cache cleaned up.")

@classmethod
def _cleanup_memory(cls):
"""
Force garbage collection to free up memory after tensor processing.
"""
gc.collect()

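As a rough illustration of the contributor's flag suggestion above (hypothetical code, not part of this PR): keep one boolean per model id, set it at the single place where the bulk load happens, and check it everywhere else instead of re-walking every tensor.

from typing import Dict, Optional

from onnx import ModelProto, external_data_helper


class FlagTrackedTransform:
    _external_data_loaded: Dict[int, bool] = {}

    @classmethod
    def _load_external_data(cls, model: ModelProto, onnx_base_dir: Optional[str] = None) -> None:
        model_id = id(model)
        if not cls._external_data_loaded.get(model_id, False):
            external_data_helper.load_external_data_for_model(model, onnx_base_dir)
            cls._external_data_loaded[model_id] = True  # set once here; checked everywhere else
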

class FP16ClipTransform(OnnxTransform):
"""
@@ -42,26 +112,42 @@ def apply(cls, model: ModelProto, *, onnx_base_dir: Optional[str] = None, **kwar
"""
:param onnx_base_dir: Base directory to load tensors
"""
finfo = np.finfo(np.float16)
fp16_max = finfo.max
fp16_min = finfo.min
transformed = False
try:
# --- FIX: Ensure external data is loaded efficiently BEFORE processing ---
cls._load_external_data(model, onnx_base_dir)

for tensor in external_data_helper._get_all_tensors(model):
nptensor = numpy_helper.to_array(tensor, onnx_base_dir)
if nptensor.dtype == np.float32 and (np.any(nptensor > fp16_max) or np.any(nptensor < fp16_min)):
neg_inf_mask = np.isinf(nptensor) & (nptensor < 0)
clipped_tensor = np.clip(nptensor, fp16_min, fp16_max)
finfo = np.finfo(np.float16)
fp16_max = finfo.max
fp16_min = finfo.min
transformed = False

processed_count = 0
for tensor in external_data_helper._get_all_tensors(model):
nptensor = numpy_helper.to_array(tensor) # Removed onnx_base_dir as data is already loaded
if nptensor.dtype == np.float32 and (np.any(nptensor > fp16_max) or np.any(nptensor < fp16_min)):
neg_inf_mask = np.isinf(nptensor) & (nptensor < 0)
clipped_tensor = np.clip(nptensor, fp16_min, fp16_max)

# Restore -inf values
if neg_inf_mask.any():
clipped_tensor = np.where(neg_inf_mask, np.float32("-inf"), clipped_tensor)
# Restore -inf values
if neg_inf_mask.any():
clipped_tensor = np.where(neg_inf_mask, np.float32("-inf"), clipped_tensor)

new_tensor = numpy_helper.from_array(clipped_tensor, tensor.name)
tensor.CopyFrom(new_tensor)
transformed = True
new_tensor = numpy_helper.from_array(clipped_tensor, tensor.name)
tensor.CopyFrom(new_tensor)
transformed = True

return model, transformed
del neg_inf_mask, clipped_tensor, new_tensor

Contributor comment: In this loop itself you can check and then update the flag.

del nptensor
processed_count += 1

if processed_count % ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL == 0:
cls._cleanup_memory()

return model, transformed
finally:
# Ensure cleanup happens even if an exception occurs
cls._cleanup_memory()

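A small numeric illustration of why FP16ClipTransform clamps values before the float16 cast (NumPy only; the numbers are made up): float32 magnitudes beyond the float16 range overflow to inf when cast, so they are clipped to the float16 extremes first.

import numpy as np

fp16 = np.finfo(np.float16)
weights = np.array([1e5, -1e5, 0.5], dtype=np.float32)

print(weights.astype(np.float16))  # 1e5 overflows to inf, -1e5 to -inf
print(np.clip(weights, fp16.min, fp16.max).astype(np.float16))  # clamped to +/-65504, then cast safely
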

class SplitTensorsTransform(OnnxTransform):
@@ -86,16 +172,30 @@ def apply(
:param file_chunk_size: Chunk size to split external files into.
:param size_threshold: Only tensors greater than this threshold (in bytes) will be saved externally.
"""
file_num = 0
current_file_size = 0
transformed = False
external_data_helper.load_external_data_for_model(model, onnx_base_dir)
for tensor in external_data_helper._get_all_tensors(model):
if tensor.HasField("raw_data") and ((tsize := len(tensor.raw_data)) > size_threshold):
transformed = True
current_file_size += tsize
if current_file_size > file_chunk_size:
file_num += 1
current_file_size = tsize
external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data")
return model, transformed
try:
file_num = 0
current_file_size = 0
transformed = False

# --- Adjustment: The initial check and load will now use the new bulk loader ---
# This will either use the cache (if FP16ClipTransform loaded it) or perform the bulk load itself.
cls._load_external_data(model, onnx_base_dir)

processed_count = 0
for tensor in external_data_helper._get_all_tensors(model):
if tensor.HasField("raw_data") and ((tsize := len(tensor.raw_data)) > size_threshold):
transformed = True
current_file_size += tsize
if current_file_size > file_chunk_size:
file_num += 1
current_file_size = tsize
external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data")

processed_count += 1
if processed_count % ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL == 0:
cls._cleanup_memory()

return model, transformed
finally:
# Ensure cleanup happens even if an exception occurs
cls._cleanup_memory()
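
The file-splitting rule in SplitTensorsTransform boils down to a small accumulation loop; the toy helper below (hypothetical name, toy sizes) shows it in isolation: tensors are assigned to the current external-data file until its running size exceeds file_chunk_size, at which point the file index advances and the running size restarts from the current tensor.

def assign_chunks(tensor_sizes, file_chunk_size):
    file_num, current_file_size, assignment = 0, 0, []
    for tsize in tensor_sizes:
        current_file_size += tsize
        if current_file_size > file_chunk_size:
            file_num += 1
            current_file_size = tsize
        assignment.append(file_num)
    return assignment


print(assign_chunks([4, 4, 4, 4], file_chunk_size=10))  # [0, 0, 1, 1]
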
1 change: 1 addition & 0 deletions QEfficient/utils/constants.py
@@ -87,6 +87,7 @@ def get_models_dir():

COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw"]
DEFAULT_AIC_HW_VERSION = "ai100"
ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL = 100

# InternVL constants
# Fixing the feature size with reference to OpenGVLab/InternVL2_5-1B, OpenGVLab/InternVL2_5-38B and OpenGVLab/InternVL2_5-78B
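
Finally, a sketch of how the new ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL constant is meant to be used; the loop below is a stand-in for the tensor walk in the transforms, not QEfficient API: a garbage-collection pass every 100 processed tensors keeps peak memory bounded on large models.

import gc

ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL = 100

for processed_count in range(1, 1001):
    # ... transform one tensor here ...
    if processed_count % ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL == 0:
        gc.collect()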