def _clear_model_weights(self) -> bool:
    """
    Release the PyTorch weights held by ``self.model`` after ONNX export.

    The storage behind every parameter and buffer is shrunk to zero bytes, the
    per-module parameter, buffer and hook dictionaries are emptied, and
    ``self.model`` is finally replaced by a weightless ``ModelShell`` that
    keeps only ``config`` plus a few no-op ``nn.Module``-style accessors.

    Clearing is best effort: a tensor whose storage refuses to be resized is
    skipped instead of aborting the whole pass, and the shell is installed even
    when something failed, so a half-cleared module is never left behind on
    ``self.model``.

    Returns:
        bool: True when every step succeeded, False when at least one tensor
        or module could not be cleared (the first error is logged as a
        warning).
    """

    class ModelShell:
        """Weightless stand-in that answers the basic module accessors."""

        def __init__(self, config):
            self.config = config
            self.qaic_config = None
            self.device = torch.device("meta")

        def parameters(self):
            return iter([])

        def named_parameters(self):
            return iter([])

        def buffers(self):
            return iter([])

        def named_buffers(self):
            return iter([])

        def modules(self):
            return iter([self])

        def state_dict(self):
            return {}

        def to(self, device):
            return self

        def eval(self):
            return self

    # Per-module dictionaries that keep tensors or callables alive.
    registries = (
        "_parameters",
        "_buffers",
        "_forward_hooks",
        "_forward_pre_hooks",
        "_backward_hooks",
        "_state_dict_hooks",
        "_load_state_dict_pre_hooks",
    )

    def release(tensor) -> None:
        """Shrink the storage behind ``tensor`` to zero bytes."""
        data = getattr(tensor, "data", None)
        if data is None:
            return
        # Prefer ``untyped_storage()``; fall back to the older ``storage()``
        # accessor on tensors that do not provide it.
        accessor = getattr(data, "untyped_storage", None) or getattr(data, "storage", None)
        if accessor is not None:
            accessor().resize_(0)

    model = self.model
    # Read the config up front so the shell can be built even if clearing fails.
    config = getattr(model, "config", None)
    errors = []

    try:
        # Collect the tensors first: the dictionaries they are read from are
        # emptied further down.
        tensors = list(model.parameters()) + list(model.buffers())
        for tensor in tensors:
            try:
                release(tensor)
            except Exception as e:
                # One stubborn tensor must not keep all the others allocated.
                errors.append(e)

        for module in model.modules():
            for name in registries:
                registry = getattr(module, name, None)
                if registry is not None:
                    registry.clear()
    except Exception as e:
        errors.append(e)
    finally:
        # Always drop the reference to the old module.
        self.model = ModelShell(config)

    if errors:
        logger.warning(f"Weight clearing failed, continuing: {errors[0]}")
        return False
    return True
class OnnxTransform:
    """
    OnnxTransform is the base class for graph modifications on exported onnx.
    """

    # Maps id(model) -> whether that model's external tensor data is loaded.
    # An entry is only stored when it can be evicted automatically once the
    # model is garbage collected (see `_remember_load_state`).
    _external_data_loaded_cache = {}  # Dict[int, bool]

    def __init__(self):
        raise TypeError("Transform classes are not to be instantiated. Directly use the `apply` method.")

    @classmethod
    def apply(cls, model: "ModelProto", **kwargs) -> Tuple["ModelProto", bool]:
        """
        Apply the transform to ``model``; the base implementation always raises.

        :param model: The ONNX model to transform
        :raises NotImplementedError: Always; subclasses provide the transform
        """
        raise NotImplementedError("Use subclasses for ONNX transform")

    @classmethod
    def _remember_load_state(cls, model: "ModelProto", loaded: bool) -> None:
        """
        Cache ``loaded`` for ``model`` when a stale entry can be ruled out.

        The cache is keyed by ``id(model)``, and an id may be handed to a new
        object once the old one is freed. A finalizer therefore removes the
        entry when the model is collected. Models that do not support weak
        references are not cached at all, so callers re-scan them instead.
        """
        import weakref  # local import: only this helper needs it

        cache = cls._external_data_loaded_cache
        model_id = id(model)
        if model_id not in cache:
            try:
                weakref.finalize(model, cache.pop, model_id, None)
            except TypeError:
                # Eviction cannot be guaranteed, so do not cache this model.
                return
        cache[model_id] = loaded

    @classmethod
    def _check_external_data_loaded(cls, model: "ModelProto") -> bool:
        """
        Check if external data is already loaded in the model.

        :param model: The ONNX model to check
        :returns: True if external data is already loaded, False otherwise
        """
        cache = cls._external_data_loaded_cache
        model_id = id(model)
        if model_id in cache:
            return cache[model_id]

        # A tensor that still lists external data but carries no raw bytes
        # has not been loaded yet; one such tensor is enough to answer False.
        loaded = not any(
            len(tensor.external_data) > 0 and not tensor.HasField("raw_data")
            for tensor in external_data_helper._get_all_tensors(model)
        )
        cls._remember_load_state(model, loaded)
        return loaded

    @classmethod
    def _load_external_data(cls, model: "ModelProto", onnx_base_dir: Optional[str] = None):
        """
        Performs a bulk load of external data if it's not already loaded.
        Updates the cache upon successful load.

        :param model: The ONNX model whose tensors should be loaded
        :param onnx_base_dir: Base directory to load tensors
        """
        if cls._check_external_data_loaded(model):
            logger.info("External data already loaded (or cached). Skipping bulk load.")
            return

        logger.info("External data not loaded. Performing bulk load.")
        external_data_helper.load_external_data_for_model(model, onnx_base_dir)
        cls._remember_load_state(model, True)

    @classmethod
    def _cleanup_external_data_and_cache(cls, model: "ModelProto"):
        """
        Combines clearing external data from the model and its cache entry.

        Every tensor that has a ``raw_data`` field gets it cleared, whether or
        not the tensor was marked as external.
        """
        for tensor in external_data_helper._get_all_tensors(model):
            if tensor.HasField("raw_data"):
                tensor.ClearField("raw_data")

        # Drop the cache entry for this model, if there is one.
        cls._external_data_loaded_cache.pop(id(model), None)

        logger.info("External data and cache cleaned up.")

    @classmethod
    def _cleanup_memory(cls):
        """
        Force garbage collection to free up memory after tensor processing.
        """
        gc.collect()
transformed + finally: + # Ensure cleanup happens even if an exception occurs + cls._cleanup_memory() class SplitTensorsTransform(OnnxTransform): @@ -86,16 +172,30 @@ def apply( :param file_chunk_size: Chunk size to split external files into. :param size_threshold: Only tensors greater than this threshold (in bytes) will be saved externally. """ - file_num = 0 - current_file_size = 0 - transformed = False - external_data_helper.load_external_data_for_model(model, onnx_base_dir) - for tensor in external_data_helper._get_all_tensors(model): - if tensor.HasField("raw_data") and ((tsize := len(tensor.raw_data)) > size_threshold): - transformed = True - current_file_size += tsize - if current_file_size > file_chunk_size: - file_num += 1 - current_file_size = tsize - external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data") - return model, transformed + try: + file_num = 0 + current_file_size = 0 + transformed = False + + # --- Adjustment: The initial check and load will now use the new bulk loader --- + # This will either use the cache (if FP16ClipTransform loaded it) or perform the bulk load itself. 
+ cls._load_external_data(model, onnx_base_dir) + + processed_count = 0 + for tensor in external_data_helper._get_all_tensors(model): + if tensor.HasField("raw_data") and ((tsize := len(tensor.raw_data)) > size_threshold): + transformed = True + current_file_size += tsize + if current_file_size > file_chunk_size: + file_num += 1 + current_file_size = tsize + external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data") + + processed_count += 1 + if processed_count % ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL == 0: + cls._cleanup_memory() + + return model, transformed + finally: + # Ensure cleanup happens even if an exception occurs + cls._cleanup_memory() diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 92d0b32f2..796c442a0 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -87,6 +87,7 @@ def get_models_dir(): COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw"] DEFAULT_AIC_HW_VERSION = "ai100" +ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL = 100 # InternVL constants # Fixing the feature size with reference to OpenGVLab/InternVL2_5-1B, OpenGVLab/InternVL2_5-38B and OpenGVLab/InternVL2_5-78B diff --git a/scripts/memory_profiling/README.md b/scripts/memory_profiling/README.md new file mode 100644 index 000000000..d24d3155a --- /dev/null +++ b/scripts/memory_profiling/README.md @@ -0,0 +1,181 @@ +# QEfficient Memory Profiling + +A memory profiling solution for QEfficient workflows with manual operation marking. 
+ + + +## Quick Start + +```python +from scripts.memory_profiling import QEffMemoryProfiler +from QEfficient import QEFFAutoModelForCausalLM + +# Initialize profiler +profiler = QEffMemoryProfiler(verbose=True) +profiler.start_monitoring() + +# Your QEfficient workflow +model = QEFFAutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") +model.export() +model.compile(prefill_seq_len=128, ctx_len=256, num_cores=16) +output = model.generate(prompts=["Hello world"]) + +# Generate report and visualization +profiler.stop_monitoring() +print(profiler.get_memory_report()) +profiler.generate_memory_graph("profile.png") +``` + +## Configuration + +### Basic Configuration + +```python +profiler = QEffMemoryProfiler( + sampling_interval=0.1, # Sample every 100ms + output_file="my_profile.png", # Custom output file + verbose=True, # Enable detailed logging + enable_cpu_monitoring=True, # Monitor CPU usage + enable_disk_monitoring=True, # Monitor disk I/O +) +``` + +### Manual Operation Marking + +```python +profiler = QEffMemoryProfiler() +profiler.start_monitoring() + +# Manual operation marking +profiler.mark_operation("Custom Operation 1") +# ... your code ... + +profiler.mark_operation("Custom Operation 2") +# ... more code ... 
+ +profiler.stop_monitoring() +``` + +## API Reference + +### QEffMemoryProfiler + +#### Constructor Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `sampling_interval` | `float` | `0.05` | Time between samples (seconds) | +| `output_file` | `str` | `"qeff_memory_profile.png"` | Output file path | +| `verbose` | `bool` | `False` | Enable verbose logging | +| `enable_cpu_monitoring` | `bool` | `True` | Monitor CPU usage | +| `enable_disk_monitoring` | `bool` | `True` | Monitor disk I/O | + +#### Methods + +- **`start_monitoring()`**: Start background monitoring +- **`stop_monitoring()`**: Stop monitoring and mark completion +- **`mark_operation(name: str)`**: Manually mark operation start +- **`get_memory_report() -> str`**: Generate comprehensive text report +- **`generate_memory_graph(filename: str)`**: Create visualization +- **`stop_and_save(filename: str) -> str`**: Convenience method to stop and save + +#### Properties + +- **`peak_rss`**: Peak RSS memory usage (MB) +- **`peak_operation`**: Operation during peak memory +- **`samples`**: List of collected profiling samples +- **`operations`**: List of marked operations with timestamps + +## Operation Types + +The profiler supports marking these common QEfficient operations: + +- **Model Loading**: `from_pretrained`, `AutoModel`, `AutoTokenizer` +- **Export**: `model.export()`, ONNX transforms, PyTorch transforms +- **Compilation**: `model.compile()`, QNN compilation +- **Generation**: `model.generate()`, inference execution +- **Cleanup**: Memory cleanup, garbage collection + +## Output + +### Console Report +``` +QEFFICIENT PERFORMANCE MONITORING REPORT +============================================================ +Peak Memory Usage: + • RSS (Physical): 18.7 GB at 14:23:45 + • Peak during: Compilation + +Memory Statistics: + • Current RSS: 16.2 GB (Delta: +15.8 GB) + • Duration: 185.3 seconds + • Operations: 4 + +QEfficient Operations Timeline: + 1. 
0.0s - Model Loading (25.2s) [+8.2 GB] + 2. 25.2s - Export (15.4s) [+2.1 GB] + 3. 40.6s - Compilation (120.8s) [+6.3 GB] <- Peak + 4. 161.4s - Generation (18.7s) [+1.2 GB] +``` + +### Visualization + +The profiler generates a comprehensive 4-panel visualization: + +1. **Memory Timeline**: RSS usage with colored operation phases +2. **CPU Usage**: CPU utilization with performance zones +3. **Disk I/O**: Read/write activity per operation phase +4. **Phase Duration**: Timing analysis with duration labels + +#### Sample Output + +![Sample Memory Profile](memory_profile_llama3.2.png) + +*Example memory profiling output showing QEfficient workflow phases including model loading, ONNX transforms, compilation, and generation phases with detailed memory, CPU, and disk I/O metrics.* + +## Advanced Usage + + +### Accessing Raw Data + +```python +# Get synchronized data arrays +data = profiler.get_synchronized_data() +timestamps = data['timestamps'] +memory_usage = data['rss_memory'] +cpu_usage = data['cpu_usage'] + +# Access individual samples +for sample in profiler.samples: + print(f"Time: {sample.timestamp}, RSS: {sample.rss_mb} MB") +``` + +## Integration Examples + +### With Existing QEfficient Scripts + +```python +# Add to existing QEfficient workflow +profiler = QEffMemoryProfiler(output_file="workflow_profile.png") +profiler.start_monitoring() + +# Existing QEfficient code unchanged +model = QEFFAutoModelForCausalLM.from_pretrained(model_name) +# ... rest of workflow ... + +# Add at end +report = profiler.stop_and_save() +print(report) +``` + + +## Limitations + +### Disk I/O Tracking + +**Subprocess I/O Limitation**: Disk I/O tracking captures parent process I/O only. Subprocess I/O (e.g., compilation reading ONNX files via `subprocess.run()`) is not captured due to Linux I/O accounting limitations. During compilation phases, expect lower I/O readings than actual file operations performed by subprocesses. 
+ +## Compatibility + +- **Python**: 3.7+ +- **Dependencies**: `psutil`, `matplotlib`, `numpy` diff --git a/scripts/memory_profiling/__init__.py b/scripts/memory_profiling/__init__.py new file mode 100644 index 000000000..dc1377d0b --- /dev/null +++ b/scripts/memory_profiling/__init__.py @@ -0,0 +1,53 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +QEfficient Memory Profiling + +A production-ready memory profiling solution specifically designed for QEfficient workflows. +Provides manual operation marking, comprehensive metrics collection, and professional visualization. + +Usage Example: + +```python +from scripts.memory_profiling import QEffMemoryProfiler + +profiler = QEffMemoryProfiler(verbose=True) +profiler.start_monitoring() +# ... your QEfficient code ... +profiler.stop_monitoring() +print(profiler.get_memory_report()) +profiler.generate_memory_graph() +``` +""" + +__version__ = "2.0.0" +__author__ = "Qualcomm Technologies, Inc." 
+ +# Core profiler components +from .profiler import ( + MetricsCollector, + ProfilerConfig, + ProfileSample, + QEffMemoryProfiler, +) + +# Visualization component (imported on-demand) +try: + from .visualizer import QEffMemoryVisualizer +except ImportError: + # Handle case where matplotlib is not available + QEffMemoryVisualizer = None + +__all__ = [ + "QEffMemoryProfiler", + "ProfilerConfig", + "ProfileSample", + "MetricsCollector", + "QEffMemoryVisualizer", + "__version__", +] diff --git a/scripts/memory_profiling/memory_profile_llama3.2.png b/scripts/memory_profiling/memory_profile_llama3.2.png new file mode 100644 index 000000000..780a43855 Binary files /dev/null and b/scripts/memory_profiling/memory_profile_llama3.2.png differ diff --git a/scripts/memory_profiling/profiler.py b/scripts/memory_profiling/profiler.py new file mode 100644 index 000000000..ba7565d7f --- /dev/null +++ b/scripts/memory_profiling/profiler.py @@ -0,0 +1,729 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +QEfficient Memory Profiler - Production-Ready Memory Monitoring + +This module provides comprehensive memory profiling capabilities specifically +designed for QEfficient workflows. 
@dataclass
class ProfilerConfig:
    """Configuration for memory profiler."""

    sampling_interval: float = 0.2
    output_file: Optional[str] = None
    verbose: bool = False
    enable_cpu_monitoring: bool = True
    enable_disk_monitoring: bool = True
    track_child_processes: bool = True
    child_scan_interval: float = 1.0


@dataclass
class ProfileSample:
    """Single profiling sample containing all metrics."""

    timestamp: datetime
    rss_mb: float
    vms_mb: float
    cpu_percent: float = 0.0
    disk_read_mb: float = 0.0
    disk_write_mb: float = 0.0
    disk_read_rate: float = 0.0
    disk_write_rate: float = 0.0


class MetricsCollector:
    """Handles collection of system metrics with child process support."""

    def __init__(self, config: ProfilerConfig):
        """Bind to the current process and reset all measurement state."""
        self.config = config
        self.process = psutil.Process(os.getpid())
        self._last_disk_counters = None
        self._last_disk_time = None
        self._cpu_initialized = False
        self._last_cpu_ema = 0.0
        self._cpu_ema_alpha = 0.3
        # Looked up once: the core count is needed for every CPU sample.
        self._num_cores = os.cpu_count() or 1

        # Child process tracking
        self._track_children = config.track_child_processes
        self._child_processes: Dict[int, psutil.Process] = {}
        self._last_child_scan = 0.0
        self._child_scan_interval = config.child_scan_interval
        self._child_cpu_cache: Dict[int, float] = {}

        if self._track_children and self.config.verbose:
            logger.info("Child process tracking enabled")

    def initialize_cpu_monitoring(self) -> None:
        """Initialize CPU monitoring for the parent and any tracked children."""
        try:
            self.process.cpu_percent()  # First call to establish baseline
            self._cpu_initialized = True

            # Initialize child process CPU monitoring
            if self._track_children:
                self._update_child_processes()
                for child_proc in self._child_processes.values():
                    try:
                        child_proc.cpu_percent()  # Initialize baseline for children
                    except (psutil.NoSuchProcess, psutil.AccessDenied):
                        continue

            if self.config.verbose:
                logger.info("CPU measurement initialized")
        except Exception as e:
            if self.config.verbose:
                logger.warning(f"CPU initialization warning: {e}")
            self._cpu_initialized = False

    def _update_child_processes(self) -> None:
        """
        Discover and track child processes (compilation subprocesses).

        Scans are rate limited: at most one per ``child_scan_interval`` seconds
        while no child is tracked, and one every 5 seconds once at least one
        child is tracked.
        """
        current_time = time.time()
        scan_interval = 5.0 if self._child_processes else self._child_scan_interval
        if current_time - self._last_child_scan < scan_interval:
            return

        try:
            # Get current children (recursive to catch subprocess chains)
            children = self.process.children(recursive=True)

            # Add new children
            new_children_count = 0
            for child in children:
                if child.pid not in self._child_processes:
                    try:
                        # Verify child is still running and accessible
                        if child.is_running():
                            self._child_processes[child.pid] = child
                            self._child_cpu_cache[child.pid] = 0.0

                            # Initialize CPU monitoring for new child
                            try:
                                child.cpu_percent()  # First call to establish baseline
                            except (psutil.NoSuchProcess, psutil.AccessDenied):
                                pass  # Child may have terminated quickly

                            new_children_count += 1

                            if self.config.verbose:
                                try:
                                    cmd_name = child.name()
                                    logger.info(f"Tracking new subprocess: PID {child.pid} ({cmd_name})")
                                except (psutil.NoSuchProcess, psutil.AccessDenied):
                                    logger.info(f"Tracking new subprocess: PID {child.pid}")
                    except (psutil.NoSuchProcess, psutil.AccessDenied):
                        continue

            # Remove terminated children
            terminated_pids = []
            for pid, proc in self._child_processes.items():
                try:
                    if not proc.is_running():
                        terminated_pids.append(pid)
                except (psutil.NoSuchProcess, psutil.AccessDenied):
                    terminated_pids.append(pid)

            for pid in terminated_pids:
                if pid in self._child_processes:
                    del self._child_processes[pid]
                if pid in self._child_cpu_cache:
                    del self._child_cpu_cache[pid]
                if self.config.verbose:
                    logger.info(f"Removed terminated subprocess: PID {pid}")

            if new_children_count > 0 and self.config.verbose:
                logger.info(f"Now tracking {len(self._child_processes)} child processes")

        except Exception as e:
            if self.config.verbose:
                logger.warning(f"Child process scan error: {e}")

        self._last_child_scan = current_time

    def get_memory_usage(self) -> Tuple[float, float]:
        """
        Get current memory usage in MB (parent + children).

        Returns:
            Tuple of (RSS, VMS), each computed as bytes / 1024 / 1024;
            ``(0.0, 0.0)`` if collection fails.
        """
        try:
            # Parent process memory
            mem_info = self.process.memory_info()
            total_rss = mem_info.rss / 1024 / 1024
            total_vms = mem_info.vms / 1024 / 1024

            # Add child process memory (if tracking enabled)
            if self._track_children:
                child_rss = 0.0
                child_vms = 0.0
                active_children = 0
                stale_children = []

                for pid, child_proc in self._child_processes.items():
                    try:
                        child_mem = child_proc.memory_info()
                        child_rss += child_mem.rss / 1024 / 1024
                        child_vms += child_mem.vms / 1024 / 1024
                        active_children += 1
                    except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
                        # Mark child as stale for cleanup
                        stale_children.append(pid)
                        continue

                # Clean up stale children (don't do this during iteration)
                for pid in stale_children:
                    if pid in self._child_processes:
                        del self._child_processes[pid]
                    if pid in self._child_cpu_cache:
                        del self._child_cpu_cache[pid]

                total_rss += child_rss
                total_vms += child_vms

                if self.config.verbose and active_children > 0:
                    logger.debug(
                        f"Memory: Parent {mem_info.rss / 1024 / 1024:.1f}MB + "
                        f"Children {child_rss:.1f}MB = Total {total_rss:.1f}MB RSS"
                    )

            return total_rss, total_vms
        except Exception as e:
            if self.config.verbose:
                logger.warning(f"Memory collection error: {e}")
            return 0.0, 0.0

    def get_cpu_usage(self) -> float:
        """
        Get CPU usage with child processes included and smoothing.

        The raw per-process percentages of the parent and its tracked children
        are summed, divided by the core count, capped at 100 and passed through
        an exponential moving average.

        Returns:
            The smoothed percentage; ``0.0`` when CPU monitoring is disabled,
            or the last smoothed value if collection fails.
        """
        if not self.config.enable_cpu_monitoring:
            return 0.0

        try:
            num_cores = self._num_cores

            parent_cpu_raw = 0.0
            child_cpu_raw_total = 0.0

            # Parent CPU (raw percentage, can be >100% on multi-core)
            if self._cpu_initialized:
                parent_cpu_raw = self.process.cpu_percent()
                if parent_cpu_raw < 0:
                    parent_cpu_raw = 0.0

            # Child CPU (if tracking enabled)
            if self._track_children:
                active_children = 0

                for pid, child_proc in list(self._child_processes.items()):
                    try:
                        child_cpu_raw = child_proc.cpu_percent()
                        if child_cpu_raw >= 0:
                            # Cache raw CPU value
                            self._child_cpu_cache[pid] = child_cpu_raw
                            child_cpu_raw_total += child_cpu_raw
                            active_children += 1
                    except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
                        # Use cached value if available, otherwise skip
                        if pid in self._child_cpu_cache:
                            child_cpu_raw_total += self._child_cpu_cache[pid]
                        continue

                if self.config.verbose and active_children > 0:
                    # Convert to system-wide percentage for logging
                    parent_system_pct = parent_cpu_raw / num_cores
                    child_system_pct = child_cpu_raw_total / num_cores
                    logger.debug(
                        f"CPU: Parent {parent_system_pct:.1f}% + "
                        f"Children {child_system_pct:.1f}% (from {active_children} processes) "
                        f"= {parent_system_pct + child_system_pct:.1f}% system-wide"
                    )

            # Dividing the summed per-process figures by the number of cores
            # gives a system-wide percentage.
            total_process_cpu = parent_cpu_raw + child_cpu_raw_total
            system_wide_cpu = total_process_cpu / num_cores

            # Cap at 100% (shouldn't exceed this in normal cases)
            system_wide_cpu = min(system_wide_cpu, 100.0)

            # Apply exponential moving average smoothing
            if system_wide_cpu > 0 or self._last_cpu_ema > 0:
                smoothed_cpu = self._cpu_ema_alpha * system_wide_cpu + (1 - self._cpu_ema_alpha) * self._last_cpu_ema
                self._last_cpu_ema = smoothed_cpu
                return smoothed_cpu

            return 0.0
        except Exception as e:
            if self.config.verbose:
                logger.warning(f"CPU collection error: {e}")
            return self._last_cpu_ema

    def get_disk_io_stats(self) -> Tuple[float, float, float, float]:
        """
        Get disk I/O statistics with rate calculation (parent + children).

        Returns:
            Tuple of (read MB, write MB, read MB/s, write MB/s). The first two
            are the cumulative counters summed over the parent and the tracked
            children; the rates are relative to the previous call and never
            negative. All four values are ``0.0`` when disk monitoring is
            disabled or collection fails.
        """
        if not self.config.enable_disk_monitoring:
            return 0.0, 0.0, 0.0, 0.0

        try:
            current_time = time.time()

            # Parent process I/O
            parent_io = self.process.io_counters()

            # Determine which counters to use
            use_chars = hasattr(parent_io, "read_chars") and hasattr(parent_io, "write_chars")

            if use_chars:
                total_read_bytes = parent_io.read_chars
                total_write_bytes = parent_io.write_chars
            else:
                total_read_bytes = parent_io.read_bytes
                total_write_bytes = parent_io.write_bytes

            # Add child process I/O (if tracking enabled)
            if self._track_children:
                child_read_total = 0
                child_write_total = 0
                active_io_children = 0

                for pid, child_proc in list(self._child_processes.items()):
                    try:
                        child_io = child_proc.io_counters()
                        if use_chars and hasattr(child_io, "read_chars") and hasattr(child_io, "write_chars"):
                            child_read_total += child_io.read_chars
                            child_write_total += child_io.write_chars
                        else:
                            child_read_total += child_io.read_bytes
                            child_write_total += child_io.write_bytes
                        active_io_children += 1
                    except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
                        # Child process terminated or inaccessible
                        continue

                total_read_bytes += child_read_total
                total_write_bytes += child_write_total

                if self.config.verbose and active_io_children > 0:
                    parent_read_mb = (
                        parent_io.read_chars / 1024 / 1024 if use_chars else parent_io.read_bytes / 1024 / 1024
                    )
                    parent_write_mb = (
                        parent_io.write_chars / 1024 / 1024 if use_chars else parent_io.write_bytes / 1024 / 1024
                    )
                    child_read_mb = child_read_total / 1024 / 1024
                    child_write_mb = child_write_total / 1024 / 1024
                    logger.debug(
                        f"Disk I/O: Parent R:{parent_read_mb:.1f}MB W:{parent_write_mb:.1f}MB + "
                        f"Children R:{child_read_mb:.1f}MB W:{child_write_mb:.1f}MB "
                        f"(from {active_io_children} processes)"
                    )

            # Convert to MB
            read_mb = total_read_bytes / 1024 / 1024
            write_mb = total_write_bytes / 1024 / 1024

            # Calculate rates
            read_rate = 0.0
            write_rate = 0.0

            if self._last_disk_counters is not None and self._last_disk_time is not None:
                time_delta = current_time - self._last_disk_time
                if time_delta > 0:
                    # Calculate delta from last measurement
                    if use_chars:
                        last_read = self._last_disk_counters.get("read_chars", 0)
                        last_write = self._last_disk_counters.get("write_chars", 0)
                    else:
                        last_read = self._last_disk_counters.get("read_bytes", 0)
                        last_write = self._last_disk_counters.get("write_bytes", 0)

                    # The summed total shrinks when a tracked child exits and
                    # its counters drop out of the sum; clamp so that this is
                    # not reported as a negative transfer rate.
                    read_delta = max(0.0, (total_read_bytes - last_read) / 1024 / 1024)  # MB
                    write_delta = max(0.0, (total_write_bytes - last_write) / 1024 / 1024)  # MB

                    read_rate = read_delta / time_delta  # MB/s
                    write_rate = write_delta / time_delta  # MB/s

            # Update counters (store as dict to handle both counter types)
            if use_chars:
                self._last_disk_counters = {"read_chars": total_read_bytes, "write_chars": total_write_bytes}
            else:
                self._last_disk_counters = {"read_bytes": total_read_bytes, "write_bytes": total_write_bytes}
            self._last_disk_time = current_time

            return read_mb, write_mb, read_rate, write_rate

        except Exception as e:
            if self.config.verbose:
                logger.warning(f"Disk I/O collection error: {e}")
            return 0.0, 0.0, 0.0, 0.0

    def collect_sample(self) -> ProfileSample:
        """Collect a complete profiling sample (memory, CPU and disk I/O)."""
        timestamp = datetime.now()
        rss_mb, vms_mb = self.get_memory_usage()
        cpu_percent = self.get_cpu_usage()
        # get_disk_io_stats() already reports megabytes, not bytes.
        read_mb, write_mb, read_rate, write_rate = self.get_disk_io_stats()

        return ProfileSample(
            timestamp=timestamp,
            rss_mb=rss_mb,
            vms_mb=vms_mb,
            cpu_percent=cpu_percent,
            disk_read_mb=read_mb,
            disk_write_mb=write_mb,
            disk_read_rate=read_rate,
            disk_write_rate=write_rate,
        )
self.operation_durations: Dict[str, float] = {} + self.operation_memory_deltas: Dict[str, float] = {} + + # Legacy property accessors for backward compatibility + @property + def timestamps(self) -> List[datetime]: + """Get timestamps from samples.""" + return [sample.timestamp for sample in self.samples] + + @property + def rss_memory(self) -> List[float]: + """Get RSS memory values from samples.""" + return [sample.rss_mb for sample in self.samples] + + @property + def vms_memory(self) -> List[float]: + """Get VMS memory values from samples.""" + return [sample.vms_mb for sample in self.samples] + + @property + def cpu_usage(self) -> List[float]: + """Get CPU usage values from samples.""" + return [sample.cpu_percent for sample in self.samples] + + @property + def disk_read_bytes(self) -> List[float]: + """Get disk read bytes from samples.""" + return [sample.disk_read_mb for sample in self.samples] + + @property + def disk_write_bytes(self) -> List[float]: + """Get disk write bytes from samples.""" + return [sample.disk_write_mb for sample in self.samples] + + @property + def disk_read_rate(self) -> List[float]: + """Get disk read rates from samples.""" + return [sample.disk_read_rate for sample in self.samples] + + @property + def disk_write_rate(self) -> List[float]: + """Get disk write rates from samples.""" + return [sample.disk_write_rate for sample in self.samples] + + @property + def sampling_interval(self) -> float: + """Get sampling interval.""" + return self.config.sampling_interval + + @property + def output_file(self) -> str: + """Get output file path.""" + return self.config.output_file + + @property + def verbose(self) -> bool: + """Get verbose flag.""" + return self.config.verbose + + def start_monitoring(self) -> None: + """Start continuous memory monitoring in background thread.""" + if self.monitoring: + return + + # Initialize CPU measurement + self.metrics_collector.initialize_cpu_monitoring() + + self.monitoring = True + self.monitor_thread 
= threading.Thread(target=self._monitor_loop, daemon=True) + self.monitor_thread.start() + + if self.config.verbose: + logger.info(f"QEff Memory monitoring started (sampling every {self.config.sampling_interval}s)") + + def stop_monitoring(self) -> None: + """Stop memory monitoring and generate reports.""" + if not self.monitoring: + return + + self.monitoring = False + if self.monitor_thread: + self.monitor_thread.join(timeout=1.0) + + # Mark completion + self.mark_operation("Completion") + + if self.config.verbose: + logger.info("QEff Memory monitoring stopped") + + def _monitor_loop(self) -> None: + """Background monitoring loop.""" + while self.monitoring: + try: + # Update child processes periodically (throttled internally) + if self.metrics_collector._track_children: + self.metrics_collector._update_child_processes() + + # Collect sample + sample = self.metrics_collector.collect_sample() + self.samples.append(sample) + + # Update peaks + self._update_peaks(sample) + + time.sleep(self.config.sampling_interval) + + except Exception as e: + if self.config.verbose: + logger.warning(f"Monitoring error: {e}") + break + + def _update_peaks(self, sample: ProfileSample) -> None: + """Update peak memory tracking.""" + if sample.rss_mb > self.peak_rss: + self.peak_rss = sample.rss_mb + self.peak_rss_time = sample.timestamp + self.peak_operation = self.current_operation + + if sample.vms_mb > self.peak_vms: + self.peak_vms = sample.vms_mb + self.peak_vms_time = sample.timestamp + + def mark_operation(self, operation_name: str) -> None: + """Mark the start of a new operation.""" + current_time = datetime.now() + current_rss = self.samples[-1].rss_mb if self.samples else 0.0 + + # Record previous operation duration and memory delta + if self.current_operation != "Initialization" and self.samples: + duration = (current_time - self.operation_start_time).total_seconds() + self.operation_durations[self.current_operation] = duration + + # Calculate memory delta from start of 
operation + start_idx = max(0, len(self.samples) - max(1, int(duration / self.config.sampling_interval))) + start_rss = self.samples[start_idx].rss_mb if start_idx < len(self.samples) else current_rss + memory_delta = current_rss - start_rss + self.operation_memory_deltas[self.current_operation] = memory_delta + + # Start new operation + self.current_operation = operation_name + self.operation_start_time = current_time + self.operations.append((current_time, operation_name)) + + if self.config.verbose: + logger.info(f"{operation_name} | Memory: {current_rss:.1f} MB RSS") + + def get_synchronized_data(self) -> Dict[str, List[float]]: + """Get synchronized data arrays.""" + if not self.samples: + return {} + + start_time = self.samples[0].timestamp + return { + "timestamps": [(s.timestamp - start_time).total_seconds() for s in self.samples], + "rss_memory": [s.rss_mb for s in self.samples], + "vms_memory": [s.vms_mb for s in self.samples], + "cpu_usage": [s.cpu_percent for s in self.samples], + "disk_read_bytes": [s.disk_read_mb for s in self.samples], + "disk_write_bytes": [s.disk_write_mb for s in self.samples], + "disk_read_rate": [s.disk_read_rate for s in self.samples], + "disk_write_rate": [s.disk_write_rate for s in self.samples], + } + + def mark_segment(self, segment_name: str) -> None: + """Convenience method for manual segment marking (API mode).""" + self.mark_operation(segment_name) + + def stop_and_save(self, filename: Optional[str] = None) -> str: + """Stop monitoring and save results (API mode convenience).""" + self.stop_monitoring() + self.generate_memory_graph(filename) + return self.get_memory_report() + + def get_memory_report(self) -> str: + """Generate comprehensive memory usage report.""" + if not self.samples: + return "No memory data collected" + + current_sample = self.samples[-1] + initial_sample = self.samples[0] + + # Calculate statistics + rss_values = [s.rss_mb for s in self.samples] + avg_rss = sum(rss_values) / len(rss_values) + 
max_rss = max(rss_values) + min_rss = min(rss_values) + + # Auto-scale units + rss_scale, rss_unit = (1024, "GB") if max_rss > 2048 else (1, "MB") + + # Calculate disk I/O statistics + disk_io_stats = "" + if self.samples and len(self.samples) > 1: + total_read = current_sample.disk_read_mb - initial_sample.disk_read_mb + total_write = current_sample.disk_write_mb - initial_sample.disk_write_mb + max_read_rate = max(s.disk_read_rate for s in self.samples) + max_write_rate = max(s.disk_write_rate for s in self.samples) + avg_read_rate = sum(s.disk_read_rate for s in self.samples) / len(self.samples) + avg_write_rate = sum(s.disk_write_rate for s in self.samples) / len(self.samples) + + disk_io_stats = f""" +Disk I/O Statistics: + • Total Read: {total_read:.2f} MB + • Total Write: {total_write:.2f} MB + • Peak Read Rate: {max_read_rate:.2f} MB/s + • Peak Write Rate:{max_write_rate:.2f} MB/s + • Avg Read Rate: {avg_read_rate:.2f} MB/s + • Avg Write Rate: {avg_write_rate:.2f} MB/s""" + + report = f""" +QEFFICIENT PERFORMANCE MONITORING REPORT +{"=" * 60} +Peak Memory Usage: + • RSS (Physical): {self.peak_rss / rss_scale:.2f} {rss_unit} at {self.peak_rss_time.strftime("%H:%M:%S") if self.peak_rss_time else "N/A"} + • VMS (Virtual): {self.peak_vms / rss_scale:.2f} {rss_unit} at {self.peak_vms_time.strftime("%H:%M:%S") if self.peak_vms_time else "N/A"} + • Peak during: {self.peak_operation} + +Memory Statistics: + • Current RSS: {current_sample.rss_mb / rss_scale:.2f} {rss_unit} (Delta: {(current_sample.rss_mb - initial_sample.rss_mb) / rss_scale:+.2f} {rss_unit}) + • Current VMS: {current_sample.vms_mb / rss_scale:.2f} {rss_unit} (Delta: {(current_sample.vms_mb - initial_sample.vms_mb) / rss_scale:+.2f} {rss_unit}) + • Average RSS: {avg_rss / rss_scale:.2f} {rss_unit} + • Min/Max RSS: {min_rss / rss_scale:.2f} / {max_rss / rss_scale:.2f} {rss_unit} + • Memory Range: {(max_rss - min_rss) / rss_scale:.2f} {rss_unit}{disk_io_stats} + +Monitoring Info: + • Duration: 
{(current_sample.timestamp - initial_sample.timestamp).total_seconds():.1f} seconds + • Data Points: {len(self.samples)} + • Operations: {len(self.operations)} + • Sampling Rate: {self.config.sampling_interval}s + +QEfficient Operations Timeline:""" + + # Add operation timeline + if self.operations: + start_time = self.samples[0].timestamp + for i, (op_time, op_name) in enumerate(self.operations): + relative_time = (op_time - start_time).total_seconds() + duration = self.operation_durations.get(op_name, 0) + memory_delta = self.operation_memory_deltas.get(op_name, 0) + + duration_str = f"({duration:.1f}s)" if duration > 0 else "" + memory_str = f"[{memory_delta / rss_scale:+.1f} {rss_unit}]" if abs(memory_delta) > 10 else "" + + report += f"\n {i + 1:2d}. {relative_time:6.1f}s - {op_name} {duration_str} {memory_str}" + + return report + + def generate_memory_graph(self, filename: Optional[str] = None) -> None: + """Generate professional memory usage graph with QEfficient operation segments.""" + if not self.samples: + logger.warning("No data to plot") + return + + output_file = filename or self.config.output_file + + # Import visualization module + from .visualizer import QEffMemoryVisualizer + + visualizer = QEffMemoryVisualizer(self) + visualizer.generate_professional_graph(output_file) + + if self.config.verbose: + logger.info(f"QEfficient memory profile saved as: {output_file}") + + # Legacy methods for backward compatibility + def get_memory_usage(self) -> Tuple[float, float]: + """Get current memory usage in MB (legacy method).""" + return self.metrics_collector.get_memory_usage() diff --git a/scripts/memory_profiling/visualizer.py b/scripts/memory_profiling/visualizer.py new file mode 100644 index 000000000..c16c0c0ef --- /dev/null +++ b/scripts/memory_profiling/visualizer.py @@ -0,0 +1,604 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

"""
QEfficient Memory Visualizer - Production Quality Enhanced Visualization

This module provides production-quality visualization with detailed segment analysis,
clear operation boundaries, and comprehensive memory metrics.
"""

from datetime import datetime
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np

if TYPE_CHECKING:
    from .profiler import QEffMemoryProfiler

from QEfficient.utils.logging_utils import logger


class QEffMemoryVisualizer:
    """Production-quality memory visualization with enhanced segment analysis.

    Renders a four-panel figure (memory, CPU, disk I/O per phase, phase
    durations) from the samples and operation marks held by a profiler.
    """

    def __init__(self, profiler: "QEffMemoryProfiler"):
        """Initialize visualizer with profiler data.

        Note: this also applies the visualizer's style to matplotlib's global
        ``rcParams``.
        """
        self.profiler = profiler
        self._setup_matplotlib_style()

    def _setup_matplotlib_style(self) -> None:
        """Configure matplotlib for professional styling (global rcParams)."""
        plt.style.use("default")
        plt.rcParams.update(
            {
                "font.size": 10,
                "font.family": ["DejaVu Sans", "sans-serif"],
                "axes.linewidth": 1.2,
                "figure.facecolor": "white",
                "axes.facecolor": "white",
                "grid.alpha": 0.3,
                "lines.linewidth": 2.0,
                "axes.spines.top": False,
                "axes.spines.right": False,
                "axes.edgecolor": "#333333",
                "text.color": "#333333",
                "axes.labelcolor": "#333333",
                "xtick.color": "#333333",
                "ytick.color": "#333333",
            }
        )

    @staticmethod
    def _show_placeholder(ax, message: str) -> None:
        """Write a centered grey ``message`` on an axes that has nothing to plot."""
        ax.text(
            0.5,
            0.5,
            message,
            ha="center",
            va="center",
            transform=ax.transAxes,
            fontsize=12,
            color="#666666",
        )

    @staticmethod
    def _upper_limit(peak: float, padding: float, fallback: float) -> float:
        """Return ``peak * padding``, or ``fallback`` when that is not positive.

        Setting an axis to ``(0, 0)`` makes matplotlib warn about identical
        limits and pick an arbitrary range, which happens with a single sample
        (all relative times are 0) or an all-zero series.
        """
        limit = peak * padding
        return limit if limit > 0 else fallback

    def generate_professional_graph(self, filename: str) -> None:
        """Generate enhanced multi-panel memory profile with synchronized visualization.

        Args:
            filename: Path of the PNG file to write.
        """
        if not self.profiler.samples:
            logger.warning("No data to plot")
            return

        # Get synchronized data
        sync_data = self.profiler.get_synchronized_data()

        # Create figure with professional layout - Fixed spacing to prevent title overlap
        fig = plt.figure(figsize=(20, 12), facecolor="white")
        try:
            gs = fig.add_gridspec(
                3,
                2,
                height_ratios=[2.5, 1.8, 1.2],
                width_ratios=[1, 1],
                hspace=0.35,
                wspace=0.2,
                left=0.05,
                right=0.98,
                top=0.90,
                bottom=0.08,
            )

            # Create subplots
            ax_memory = fig.add_subplot(gs[0, :])  # Memory usage (full width)
            ax_cpu = fig.add_subplot(gs[1, :])  # CPU usage (full width)
            ax_disk = fig.add_subplot(gs[2, 0])  # Disk I/O (left)
            ax_timing = fig.add_subplot(gs[2, 1])  # Phase Duration (right)

            # Prepare data
            relative_times = sync_data["timestamps"]
            max_rss = max(sync_data["rss_memory"]) if sync_data["rss_memory"] else 0
            use_gb = max_rss > 2048
            scale = 1024 if use_gb else 1
            unit = "GB" if use_gb else "MB"
            rss_scaled = [x / scale for x in sync_data["rss_memory"]]

            # Clamp CPU usage at 100% (multi-core readings can exceed it)
            normalized_cpu = [min(cpu, 100.0) for cpu in sync_data["cpu_usage"]]

            # Setup plots
            self._setup_memory_plot(ax_memory, relative_times, rss_scaled, scale, unit)
            self._setup_cpu_plot(ax_cpu, relative_times, normalized_cpu)
            self._setup_disk_io_plot(ax_disk, sync_data)
            self._setup_timing_plot(ax_timing)

            # Add main title with proper spacing
            fig.suptitle(
                "QEfficient Enhanced Memory & Performance Analysis - Synchronized View",
                fontsize=18,
                fontweight="bold",
                color="#2E86AB",
                y=0.95,
            )

            # Save this specific figure (not pyplot's "current" one) with high quality
            fig.savefig(
                filename,
                dpi=300,
                bbox_inches="tight",
                facecolor="white",
                edgecolor="none",
                format="png",
                pad_inches=0.2,
            )
        finally:
            # Always release the figure, even when plotting or saving fails
            plt.close(fig)

        logger.info(f"Enhanced synchronized memory profile saved: {filename}")

    def _setup_memory_plot(
        self, ax, relative_times: List[float], rss_scaled: List[float], scale: float, unit: str
    ) -> None:
        """Setup the main memory usage plot with enhanced visualization.

        Args:
            ax: Target axes.
            relative_times: Sample times in seconds from the first sample.
            rss_scaled: RSS values already divided by ``scale``.
            scale: Divisor applied to MB values (1 or 1024).
            unit: Unit label matching ``scale``.
        """
        if not relative_times or not rss_scaled:
            self._show_placeholder(ax, "No memory data available")
            return

        start_time = self.profiler.samples[0].timestamp
        peak_memory = max(rss_scaled)

        # Draw segment backgrounds
        self._draw_segment_backgrounds(ax, relative_times, rss_scaled, start_time)

        # Main memory line
        ax.plot(
            relative_times, rss_scaled, color="#2E86AB", linewidth=3.5, label="Memory Usage (RSS)", alpha=0.9, zorder=5
        )
        ax.fill_between(relative_times, rss_scaled, alpha=0.15, color="#2E86AB", zorder=1)

        # Add segment boundaries and annotations
        self._draw_segment_boundaries(ax, start_time, peak_memory)
        self._mark_peak_memory(ax, start_time, scale, unit)

        # Format axes
        ax.set_xlabel("Time (seconds)", fontsize=13, fontweight="bold")
        ax.set_ylabel(f"Memory Usage ({unit})", fontsize=13, fontweight="bold")
        ax.set_xlim(0, self._upper_limit(max(relative_times), 1.02, 1.0))
        ax.set_ylim(0, self._upper_limit(peak_memory, 1.15, 1.0))
        ax.grid(True, alpha=0.3, linestyle="-", linewidth=0.8, color="#CCCCCC")
        ax.set_axisbelow(True)

        # Enhanced title
        total_duration = relative_times[-1]
        ax.set_title(
            f"Memory Usage Over Time | Peak: {peak_memory:.1f} {unit} | Duration: {total_duration:.1f}s",
            fontsize=14,
            fontweight="bold",
            color="#2E86AB",
            pad=15,
        )

        # Add legend
        self._add_segment_legend(ax)

    def _setup_cpu_plot(self, ax, relative_times: List[float], cpu_usage: List[float]) -> None:
        """Setup CPU plot on the same time axis as the memory plot.

        Shows a placeholder when the CPU series is missing or its length does
        not match ``relative_times``.
        """
        if not relative_times or not cpu_usage or len(cpu_usage) != len(relative_times):
            self._show_placeholder(ax, "CPU data not available or not synchronized")
            ax.set_title("CPU Usage Over Time", fontsize=14, fontweight="bold")
            if relative_times:
                ax.set_xlim(0, self._upper_limit(max(relative_times), 1.02, 1.0))
            return

        start_time = self.profiler.samples[0].timestamp
        max_cpu = max(cpu_usage)
        avg_cpu = sum(cpu_usage) / len(cpu_usage)

        # Draw segment backgrounds for consistency
        self._draw_segment_backgrounds(ax, relative_times, cpu_usage, start_time, max_val=100)

        # Main CPU line
        ax.plot(relative_times, cpu_usage, color="#FF6B35", linewidth=3, label="CPU Usage", alpha=0.9, zorder=5)
        ax.fill_between(relative_times, cpu_usage, alpha=0.2, color="#FF6B35", zorder=1)

        # Add segment boundaries
        self._draw_segment_boundaries(ax, start_time, max_cpu)

        # Add average line
        ax.axhline(
            y=avg_cpu,
            color="#E74C3C",
            linestyle="-",
            alpha=0.8,
            linewidth=2.5,
            label=f"Average: {avg_cpu:.1f}%",
            zorder=4,
        )

        # Add performance zones
        ax.axhspan(0, 25, alpha=0.08, color="#4CAF50", zorder=0)
        ax.axhspan(25, 50, alpha=0.08, color="#FFC107", zorder=0)
        ax.axhspan(50, 75, alpha=0.08, color="#FF9800", zorder=0)
        ax.axhspan(75, 100, alpha=0.08, color="#F44336", zorder=0)

        # Format axes
        ax.set_ylabel("CPU Usage (%)", fontsize=13, fontweight="bold")
        ax.set_xlabel("Time (seconds)", fontsize=12, fontweight="bold")
        ax.set_xlim(0, self._upper_limit(max(relative_times), 1.02, 1.0))
        # An all-zero CPU series would otherwise request ylim (0, 0)
        ax.set_ylim(0, self._upper_limit(max_cpu, 1.1, 100))
        ax.grid(True, alpha=0.3, linestyle="-", linewidth=0.8, color="#CCCCCC")
        ax.set_axisbelow(True)

        # Enhanced title
        ax.set_title(
            f"CPU Usage Over Time | Peak: {max_cpu:.1f}% | Average: {avg_cpu:.1f}%",
            fontsize=14,
            fontweight="bold",
            color="#FF6B35",
            pad=15,
        )

        # Compact legend
        ax.legend(loc="upper right", fontsize=10, framealpha=0.9)

    def _setup_disk_io_plot(self, ax, sync_data: Dict[str, List[float]]) -> None:
        """Setup enhanced disk I/O plot showing phase-based analysis."""
        if not self.profiler.operations or len(self.profiler.operations) < 2:
            self._show_placeholder(ax, "No operation phases available")
            ax.set_title("Disk I/O per Phase", fontsize=14, fontweight="bold")
            return

        # Calculate I/O per phase
        operations, read_totals, write_totals = self._calculate_io_per_phase(sync_data)

        if not operations:
            self._show_placeholder(ax, "No significant disk I/O detected")
            ax.set_title("Disk I/O per Phase", fontsize=14, fontweight="bold")
            return

        # Create enhanced bar chart
        x_pos = np.arange(len(operations))
        bar_width = 0.35

        bars_read = ax.bar(
            x_pos - bar_width / 2,
            read_totals,
            bar_width,
            label="Read (MB)",
            color="#2196F3",
            alpha=0.8,
            edgecolor="white",
            linewidth=1.5,
        )
        bars_write = ax.bar(
            x_pos + bar_width / 2,
            write_totals,
            bar_width,
            label="Write (MB)",
            color="#FF5722",
            alpha=0.8,
            edgecolor="white",
            linewidth=1.5,
        )

        # Add value labels
        self._add_bar_labels(ax, bars_read, bars_write, read_totals, write_totals)

        # Format axes
        ax.set_ylabel("Total I/O (MB)", fontsize=12, fontweight="bold")
        ax.set_xlabel("Operation Phase", fontsize=11, fontweight="bold")
        ax.set_xticks(x_pos)
        ax.set_xticklabels(operations, rotation=45, ha="right", fontsize=10)

        # max(..., default=...) keeps both operands numeric even for an empty list
        max_val = max(max(read_totals, default=0.0), max(write_totals, default=0.0))
        ax.set_ylim(0, max_val * 1.25 if max_val > 0 else 1)
        ax.grid(True, alpha=0.3, linestyle="-", linewidth=0.5, color="#CCCCCC", axis="y")
        ax.set_title("Disk I/O per Operation Phase", fontsize=14, fontweight="bold", pad=15)
        ax.legend(loc="upper right", fontsize=10, framealpha=0.9)

        # Summary statistics
        total_read = sum(read_totals)
        total_write = sum(write_totals)
        ax.text(
            0.02,
            0.98,
            f"Total I/O: {total_read:.1f} MB read, {total_write:.1f} MB write",
            transform=ax.transAxes,
            fontsize=10,
            va="top",
            ha="left",
            bbox=dict(boxstyle="round,pad=0.4", facecolor="white", alpha=0.9, edgecolor="gray", linewidth=1),
        )

    def _setup_timing_plot(self, ax) -> None:
        """Setup enhanced timing analysis plot."""
        operations, durations, colors = self._get_timing_data()

        if not operations:
            self._show_placeholder(ax, "No timing data available")
            ax.set_title("Phase Duration Analysis", fontsize=14, fontweight="bold")
            return

        # Enhanced horizontal bar chart
        y_pos = np.arange(len(operations))
        bars = ax.barh(y_pos, durations, color=colors, alpha=0.8, edgecolor="white", linewidth=1.5, height=0.6)

        # Add duration labels
        self._add_duration_labels(ax, bars, durations)

        # Format axes
        ax.set_yticks(y_pos)
        ax.set_yticklabels(operations, fontsize=11)
        ax.set_xlabel("Duration (seconds)", fontsize=12, fontweight="bold")
        ax.set_title("Phase Duration Analysis", fontsize=14, fontweight="bold", pad=15)
        ax.grid(True, alpha=0.3, linestyle="-", linewidth=0.5, color="#CCCCCC", axis="x")
        ax.set_xlim(0, max(durations) * 1.2)

        # Add total duration summary
        total_duration = sum(durations)
        ax.text(
            0.98,
            0.02,
            f"Total: {total_duration:.1f}s",
            transform=ax.transAxes,
            fontsize=10,
            va="bottom",
            ha="right",
            bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.9, edgecolor="gray", linewidth=1),
        )

    def _draw_segment_backgrounds(
        self,
        ax,
        relative_times: List[float],
        data_values: List[float],
        start_time: datetime,
        max_val: Optional[float] = None,
    ) -> None:
        """Draw colored background segments for each operation.

        Each segment runs from one operation mark to the next, so the last
        marked operation gets no background. ``relative_times`` is accepted for
        signature compatibility and is not used.

        Args:
            ax: Target axes.
            relative_times: Unused.
            data_values: Series used to size the rectangles when ``max_val`` is
                not given (110% of its maximum, or 100 if empty).
            start_time: Timestamp that maps to x = 0.
            max_val: Explicit rectangle height.
        """
        operations = self.profiler.operations
        if len(operations) < 2:
            return

        max_value = max_val or (max(data_values) * 1.1 if data_values else 100)

        for (op_time, op_name), (next_time, _) in zip(operations, operations[1:]):
            op_start_time = (op_time - start_time).total_seconds()
            op_end_time = (next_time - start_time).total_seconds()

            color = self.profiler.SEGMENT_COLORS.get(op_name, "#F0F0F0")

            rect = patches.Rectangle(
                (op_start_time, 0),
                op_end_time - op_start_time,
                max_value,
                linewidth=0,
                facecolor=color,
                alpha=0.15,
                zorder=0,
            )
            ax.add_patch(rect)

    def _draw_segment_boundaries(self, ax, start_time: datetime, max_value: float) -> None:
        """Draw vertical dashed lines at every operation mark except the first.

        ``max_value`` is accepted for signature compatibility and is not used.
        """
        for op_time, _ in self.profiler.operations[1:]:
            boundary_time = (op_time - start_time).total_seconds()
            ax.axvline(x=boundary_time, color="#666666", linestyle="--", alpha=0.6, linewidth=2, zorder=3)

    def _mark_peak_memory(self, ax, start_time: datetime, scale: float, unit: str) -> None:
        """Mark peak memory with enhanced annotation.

        Does nothing when the profiler never recorded a peak RSS time.
        """
        if not self.profiler.peak_rss_time:
            return

        peak_time_rel = (self.profiler.peak_rss_time - start_time).total_seconds()
        peak_rss_scaled = self.profiler.peak_rss / scale

        # Enhanced peak marker
        ax.plot(
            peak_time_rel,
            peak_rss_scaled,
            "o",
            color="#E74C3C",
            markersize=14,
            markeredgecolor="white",
            markeredgewidth=3,
            zorder=10,
            label="Peak Memory",
        )

        # Enhanced annotation
        peak_text = f"Peak: {peak_rss_scaled:.1f} {unit}\nPhase: {self.profiler.peak_operation}"
        ax.annotate(
            peak_text,
            xy=(peak_time_rel, peak_rss_scaled),
            xytext=(25, 25),
            textcoords="offset points",
            bbox=dict(boxstyle="round,pad=0.6", facecolor="#E74C3C", alpha=0.95, edgecolor="white", linewidth=2),
            arrowprops=dict(arrowstyle="->", color="#E74C3C", lw=2.5),
            fontsize=11,
            fontweight="bold",
            color="white",
            ha="left",
            va="bottom",
            zorder=15,
        )

    def _add_segment_legend(self, ax) -> None:
        """Add a legend with one entry per distinct operation name.

        "Initialization" and "Completion" are left out; each label carries the
        operation's recorded duration when it is positive.
        """
        if not self.profiler.operations:
            return

        unique_ops = []
        seen_ops = set()
        for _, op_name in self.profiler.operations:
            if op_name not in seen_ops and op_name not in ["Initialization", "Completion"]:
                unique_ops.append(op_name)
                seen_ops.add(op_name)

        if not unique_ops:
            return

        legend_elements = []
        for op_name in unique_ops:
            color = self.profiler.SEGMENT_COLORS.get(op_name, "#666666")
            duration = self.profiler.operation_durations.get(op_name, 0)

            label = f"{op_name} ({duration:.1f}s)" if duration > 0 else op_name
            legend_elements.append(patches.Patch(color=color, alpha=0.8, label=label))

        legend = ax.legend(
            handles=legend_elements,
            loc="upper left",
            bbox_to_anchor=(1.01, 1.0),
            fontsize=11,
            title="QEfficient Phases",
            title_fontsize=12,
            framealpha=0.95,
            edgecolor="#2E86AB",
            fancybox=True,
        )
        legend.get_frame().set_facecolor("#F8F9FA")

    def _calculate_io_per_phase(self, sync_data: Dict[str, List[float]]) -> Tuple[List[str], List[float], List[float]]:
        """Calculate I/O totals per operation phase.

        A phase's total is the difference of the cumulative disk counters
        between the first sample at or after its start and the first sample at
        or after its end. Phases with at most 0.01 MB in both directions are
        omitted.

        Returns:
            ``(operation names, read totals, write totals)`` as parallel lists.
        """
        operations: List[str] = []
        read_totals: List[float] = []
        write_totals: List[float] = []

        valid_operations = [
            (op_time, op_name)
            for op_time, op_name in self.profiler.operations
            if op_name not in ["Initialization", "Completion"]
        ]

        if not valid_operations:
            return operations, read_totals, write_totals

        relative_times = sync_data["timestamps"]
        disk_reads = sync_data["disk_read_bytes"]
        disk_writes = sync_data["disk_write_bytes"]
        start_time = self.profiler.samples[0].timestamp

        for i, (op_time, op_name) in enumerate(valid_operations):
            op_start_time = (op_time - start_time).total_seconds()

            if i + 1 < len(valid_operations):
                op_end_time = (valid_operations[i + 1][0] - start_time).total_seconds()
            else:
                # The last phase extends to the final sample
                op_end_time = max(relative_times) if relative_times else op_start_time + 1

            # Find data indices
            start_idx = next((j for j, t in enumerate(relative_times) if t >= op_start_time), 0)
            end_idx = next((j for j, t in enumerate(relative_times) if t >= op_end_time), len(relative_times) - 1)

            if start_idx < len(disk_reads) and end_idx < len(disk_reads):
                read_total = disk_reads[end_idx] - disk_reads[start_idx]
                write_total = disk_writes[end_idx] - disk_writes[start_idx]

                if read_total > 0.01 or write_total > 0.01:
                    operations.append(op_name)
                    read_totals.append(max(0, read_total))
                    write_totals.append(max(0, write_total))

        return operations, read_totals, write_totals

    def _get_timing_data(self) -> Tuple[List[str], List[float], List[str]]:
        """Get timing data for operations.

        Returns one entry per distinct operation name with a positive recorded
        duration, skipping "Initialization" and "Completion". Durations are
        stored per name, so a name marked more than once must not produce a
        second bar repeating the same value.

        Returns:
            ``(operation names, durations in seconds, bar colors)``.
        """
        operations: List[str] = []
        durations: List[float] = []
        colors: List[str] = []
        seen_ops = set()

        for _, op_name in self.profiler.operations:
            if op_name in ["Initialization", "Completion"] or op_name in seen_ops:
                continue
            seen_ops.add(op_name)

            duration = self.profiler.operation_durations.get(op_name, 0)
            if duration > 0:
                operations.append(op_name)
                durations.append(duration)
                colors.append(self.profiler.SEGMENT_COLORS.get(op_name, "#666666"))

        return operations, durations, colors

    def _add_bar_labels(self, ax, bars_read, bars_write, read_totals: List[float], write_totals: List[float]) -> None:
        """Add value labels above read/write bars whose value exceeds 0.01."""
        # max(..., default=...) keeps both operands numeric even for an empty list
        max_val = max(max(read_totals, default=0.0), max(write_totals, default=0.0))
        label_offset = max_val * 0.02

        series = (
            (bars_read, read_totals, "#2196F3"),
            (bars_write, write_totals, "#FF5722"),
        )
        for bars, totals, color in series:
            for bar, value in zip(bars, totals):
                if value > 0.01:
                    ax.text(
                        bar.get_x() + bar.get_width() / 2,
                        bar.get_height() + label_offset,
                        f"{value:.1f}",
                        ha="center",
                        va="bottom",
                        fontsize=9,
                        fontweight="bold",
                        color=color,
                    )

    def _add_duration_labels(self, ax, bars, durations: List[float]) -> None:
        """Add duration labels to the right of timing bars.

        Durations of a minute or more are shown as ``"<m>m <s>s"``.
        """
        max_duration = max(durations, default=0.0)

        for bar, duration in zip(bars, durations):
            width = bar.get_width()
            minutes = int(duration // 60)
            seconds = duration % 60

            if minutes > 0:
                duration_text = f"{minutes}m {seconds:.1f}s"
            else:
                duration_text = f"{seconds:.1f}s"

            ax.text(
                width + max_duration * 0.02,
                bar.get_y() + bar.get_height() / 2,
                duration_text,
                ha="left",
                va="center",
                fontsize=10,
                fontweight="bold",
            )