diff --git a/examples/demo_v02_full.py b/examples/demo_v02_full.py new file mode 100644 index 0000000..73b6273 --- /dev/null +++ b/examples/demo_v02_full.py @@ -0,0 +1,616 @@ +""" +PyGPUkit v0.2 Full Feature Demo & Benchmark + +This demo showcases ALL v0.2 features: +=== Core Infrastructure === +1. Rust Memory Pool - LRU eviction, size classes +2. Rust Scheduler - Task management, bandwidth pacing +3. Rust Async Transfer Engine - Separate H2D/D2H streams +4. Rust Kernel Dispatch Controller - Per-stream launch management + +=== New v0.2 Features === +5. Admission Control - Deterministic admission pipeline +6. QoS Policy Framework - K8s-style QoS tiers +7. Kernel Pacing Engine - Bandwidth-based throttling +8. Micro-Slicing Framework - Kernel splitting for fairness +9. Pinned Memory Support - Page-locked host memory +10. Kernel Cache - PTX caching for NVRTC +11. GPU Partitioning - Resource isolation + +=== Compute === +12. Tiled Matmul Benchmark +""" + +import os +import sys +import time + +# Add CUDA DLLs to PATH +cuda_path = os.environ.get("CUDA_PATH", r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4") +cuda_bin = os.path.join(cuda_path, "bin") +if cuda_bin not in os.environ["PATH"]: + os.environ["PATH"] = cuda_bin + os.pathsep + os.environ["PATH"] +if hasattr(os, "add_dll_directory"): + os.add_dll_directory(cuda_bin) + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) + +import numpy as np + +# ============================================================================= +# Header +# ============================================================================= + +def print_header(title: str): + print("\n" + "=" * 70) + print(f" {title}") + print("=" * 70) + +def print_section(title: str): + print(f"\n--- {title} ---") + +# ============================================================================= +# Main Demo +# ============================================================================= + +def main(): + print_header("PyGPUkit v0.2 Complete Feature Demo") + + # Import modules + try: + import pygpukit + native = pygpukit._pygpukit_native + import _pygpukit_rust as rust + print(f"PyGPUkit version: {pygpukit.__version__}") + print("Native module loaded: OK") + print("Rust module loaded: OK") + except ImportError as e: + print(f"Import error: {e}") + return 1 + + # Environment info + print_section("Environment") + print(f"CUDA available: {native.is_cuda_available()}") + + if not native.is_cuda_available(): + print("GPU not available, exiting") + return 1 + + props = native.get_device_properties(0) + print(f"GPU: {props.name}") + print(f"Memory: {props.total_memory / 1024**3:.1f} GB") + print(f"Compute Capability: {props.compute_capability_major}.{props.compute_capability_minor}") + print(f"SMs: {props.multiprocessor_count}") + + # ========================================================================= + # 1. Rust Memory Pool Demo + # ========================================================================= + print_header("1. 
Rust Memory Pool") + + pool = rust.MemoryPool( + quota=100 * 1024 * 1024, # 100 MB quota + enable_eviction=True + ) + print(f"Created pool with 100 MB quota, eviction enabled") + + # Allocate blocks + block_ids = [] + for i, size in enumerate([1024, 4096, 16384, 65536]): + block_id = pool.allocate(size) + block_ids.append(block_id) + block = pool.get_block(block_id) + print(f" Allocated block {block.id}: {block.size} bytes") + + stats = pool.stats() + print(f"\nPool stats:") + print(f" Active: {stats.active_blocks} blocks, {stats.used} bytes") + print(f" Allocations: {stats.allocation_count}") + print(f" Quota usage: {stats.used / stats.quota:.1%}") + + # Free and reuse + pool.free(block_ids[0]) + print(f"\nFreed block {block_ids[0]}") + + new_block_id = pool.allocate(1024) + print(f"Allocated new 1024-byte block: {new_block_id} (should reuse free list)") + + stats = pool.stats() + print(f"Reuse count: {stats.reuse_count}") + + # ========================================================================= + # 2. Rust Scheduler Demo + # ========================================================================= + print_header("2. Rust Scheduler") + + scheduler = rust.Scheduler( + sched_tick_ms=10.0, + window_ms=100.0, + total_memory=1024 * 1024 * 1024 # 1 GB + ) + print("Created scheduler (10ms tick, 100ms window, 1GB memory)") + + # Submit tasks + task_ids = [] + for i in range(5): + task = rust.TaskMeta( + id=f"task_{i}", + name=f"Layer {i}", + memory_estimate=100 * 1024 * 1024, # 100 MB + priority=i % 3 + ) + task_id = scheduler.submit(task) + task_ids.append(task_id) + print(f" Submitted: {task_id} (priority: {i % 3})") + + sched_stats = scheduler.stats() + print(f"\nPending tasks: {sched_stats.pending_count}") + + # Run tasks + runnable_ids = scheduler.get_runnable_tasks(max_tasks=3) + print(f"Runnable tasks: {len(runnable_ids)}") + + # Start and complete a task + if runnable_ids: + scheduler.start_task(runnable_ids[0]) + scheduler.complete_task(runnable_ids[0]) + print(f"Completed: {runnable_ids[0]}") + + # ========================================================================= + # 3. Rust Async Transfer Engine Demo + # ========================================================================= + print_header("3. Rust Async Transfer Engine") + + transfer_engine = rust.AsyncTransferEngine(max_concurrent=4) + print("Created transfer engine (max 4 concurrent)") + + # Queue transfers + for i in range(4): + type_name = "h2d" if i % 2 == 0 else "d2h" + op_id = transfer_engine.enqueue_with_priority( + transfer_type=type_name, + src_ptr=0x1000 + i * 0x1000, + dst_ptr=0x2000 + i * 0x1000, + size=1024 * 1024, + priority=i % 3 + ) + print(f" Queued transfer {op_id}: {type_name.upper()}") + + # Simulate completion + ready = transfer_engine.get_ready_transfers(max_transfers=4) + for op in ready[:2]: + transfer_engine.start_transfer(op.id) + transfer_engine.complete_transfer(op.id) + + transfer_stats = transfer_engine.stats() + print(f"Transfer stats: {transfer_stats.completed_count} completed, {transfer_stats.pending_count} pending") + + # ========================================================================= + # 4. Rust Kernel Dispatch Controller Demo + # ========================================================================= + print_header("4. 
Rust Kernel Dispatch Controller") + + dispatcher = rust.KernelDispatcher(max_in_flight=4) + print("Created dispatcher (max 4 in-flight per stream)") + + for i in range(4): + config = rust.LaunchConfig( + grid=(128, 1, 1), + block=(256, 1, 1), + shared_mem=0, + stream_id=i % 2 + ) + req_id = dispatcher.queue( + kernel_handle=0xDEADBEEF + i, + config=config, + priority=i % 3 + ) + print(f" Queued kernel {req_id}: stream={i % 2}") + + ready_kernels = dispatcher.get_ready(max_requests=4) + for req in ready_kernels[:2]: + dispatcher.mark_launched(req.id) + dispatcher.mark_completed(req.id) + + dispatch_stats = dispatcher.stats() + print(f"Dispatch stats: {dispatch_stats.completed_count} completed, {dispatch_stats.pending_count} pending") + + # ========================================================================= + # 5. Admission Control (NEW) + # ========================================================================= + print_header("5. Admission Control (NEW)") + + print("Testing admission control with memory and bandwidth limits...") + + # Create tasks and test admission via scheduler + admission_scheduler = rust.Scheduler( + sched_tick_ms=10.0, + window_ms=100.0, + total_memory=500 * 1024 * 1024 # 500 MB limit + ) + + # Submit tasks that should fit + for i in range(3): + task = rust.TaskMeta( + id=f"admit_{i}", + name=f"Admissible Task {i}", + memory_estimate=100 * 1024 * 1024, # 100 MB each + priority=1 + ) + task_id = admission_scheduler.submit(task) + print(f" Admitted task {i}: {task_id}") + + admission_stats = admission_scheduler.stats() + print(f"\nAdmission results:") + print(f" Total submitted: {admission_stats.total_submitted}") + print(f" Reserved memory: {admission_stats.reserved_memory / 1024 / 1024:.0f} MB") + + # ========================================================================= + # 6. QoS Policy Framework (NEW) + # ========================================================================= + print_header("6. QoS Policy Framework (NEW)") + + print("Creating QoS tiers: Guaranteed, Burstable, BestEffort") + + # Create QoS policy evaluator + qos_evaluator = rust.QosPolicyEvaluator( + total_memory=1024 * 1024 * 1024, # 1 GB + total_bandwidth=1.0 + ) + + # Test different QoS classes + qos_tasks = [ + rust.QosTaskMeta.guaranteed("high-priority", "Critical Task", 256 * 1024 * 1024), + rust.QosTaskMeta.burstable("medium-priority", "Normal Task", 128 * 1024 * 1024, 2.0), + rust.QosTaskMeta.best_effort("low-priority", "Background Task"), + ] + + qos_class_names = {0: "Guaranteed", 1: "Burstable", 2: "BestEffort"} + for task in qos_tasks: + eval_result = qos_evaluator.evaluate(task) + class_name = qos_class_names.get(int(task.qos_class), "Unknown") + if eval_result.is_admitted(): + qos_evaluator.reserve(eval_result) + print(f" {class_name:12} | {task.name:15} | ADMITTED (priority={task.effective_priority()})") + elif eval_result.is_throttled(): + print(f" {class_name:12} | {task.name:15} | THROTTLED") + else: + print(f" {class_name:12} | {task.name:15} | QUEUED") + + qos_stats = qos_evaluator.stats() + print(f"\nQoS stats:") + print(f" Guaranteed memory: {qos_stats.guaranteed_memory / 1024 / 1024:.0f} MB") + print(f" Burstable memory: {qos_stats.burstable_memory / 1024 / 1024:.0f} MB") + print(f" Available memory: {qos_stats.available_memory / 1024 / 1024:.0f} MB") + print(f" Best effort queue: {qos_stats.best_effort_queue}") + + # ========================================================================= + # 7. 
Kernel Pacing Engine (NEW) + # ========================================================================= + print_header("7. Kernel Pacing Engine (NEW)") + + pacing_config = rust.PacingConfig( + total_bandwidth=1.0, + window_ms=100.0, + min_interval_ms=0.1, + adaptive=True + ) + pacing_engine = rust.KernelPacingEngine(pacing_config) + print(f"Created pacing engine: {pacing_config}") + + # Allocate bandwidth to streams + pacing_engine.allocate_stream(0, 0.6) # 60% to stream 0 + pacing_engine.allocate_stream(1, 0.3) # 30% to stream 1 + print(f"\nAllocated bandwidth: stream 0=60%, stream 1=30%") + + # Test launch decisions + for stream_id in [0, 1, 2]: # 2 is unknown + decision = pacing_engine.should_launch(stream_id) + if decision.can_launch(): + pacing_engine.record_launch(stream_id) + print(f" Stream {stream_id}: LAUNCH") + elif decision.is_throttled(): + pacing_engine.record_throttle(stream_id) + print(f" Stream {stream_id}: THROTTLED ({decision.decision_type})") + else: + print(f" Stream {stream_id}: WAIT {decision.wait_ms():.2f}ms") + + pacing_stats = pacing_engine.stats() + print(f"\nPacing stats:") + print(f" Streams: {pacing_stats.stream_count}") + print(f" Used bandwidth: {pacing_stats.used_bandwidth:.1%}") + print(f" Total launches: {pacing_stats.total_launches}") + + # ========================================================================= + # 8. Micro-Slicing Framework (NEW) + # ========================================================================= + print_header("8. Micro-Slicing Framework (NEW)") + + slice_config = rust.SliceConfig( + max_items_per_slice=10000, + max_duration_ms=1.0, + min_slices=2, + max_slices=16, + adaptive=True + ) + slice_scheduler = rust.SliceScheduler(slice_config) + print(f"Created slice scheduler: {slice_config}") + + # Submit kernels for slicing + num_slices_1 = slice_scheduler.submit( + kernel_handle=0xAAAA0001, + total_items=50000, + block=(256, 1, 1), + shared_mem=0 + ) + print(f"\nKernel 1: 50000 items -> {num_slices_1} slices") + + num_slices_2 = slice_scheduler.submit_for_task( + task_id="high-priority", + kernel_handle=0xAAAA0002, + total_items=30000, + block=(256, 1, 1), + shared_mem=0, + priority=100 + ) + print(f"Kernel 2: 30000 items -> {num_slices_2} slices (priority=100)") + + # Execute slices (round-robin) + executed = 0 + print("\nExecuting slices (round-robin):") + while executed < 4: + slice_info = slice_scheduler.get_next_slice() + if slice_info is None: + break + print(f" Slice {slice_info.slice_id}: kernel=0x{slice_info.kernel_handle:X}, offset={slice_info.offset}, count={slice_info.count}") + slice_scheduler.complete_slice(0.1) # 0.1ms exec time + executed += 1 + + slice_stats = slice_scheduler.stats() + print(f"\nSlice stats:") + print(f" Total slices: {slice_stats.total_slices}") + print(f" Completed: {slice_stats.completed_slices}") + print(f" Pending: {slice_stats.pending_slices}") + + # ========================================================================= + # 9. Pinned Memory Support (NEW) + # ========================================================================= + print_header("9. 
Pinned Memory Support (NEW)") + + pinned_config = rust.PinnedPoolConfig( + max_size=256 * 1024 * 1024, # 256 MB + enable_pooling=True, + alignment=256 + ) + pinned_manager = rust.PinnedMemoryManager(pinned_config) + print(f"Created pinned memory manager: {pinned_config}") + + # Allocate pinned memory + allocations = [] + for i, size in enumerate([4096, 65536, 1048576]): # 4KB, 64KB, 1MB + result = pinned_manager.allocate(size) + alloc_id, size_class, reused = result + if not reused: + # In real code, would call cudaHostAlloc here + pinned_manager.register(alloc_id, 0x10000000 + i * 0x100000, size_class) + allocations.append(alloc_id) + print(f" Allocated {size} bytes -> id={alloc_id}, class={size_class}, reused={reused}") + + # Associate with task + pinned_manager.associate_task(allocations[0], "task-1") + + # Free and observe pooling + should_free, ptr = pinned_manager.free(allocations[1]) + print(f"\nFreed allocation {allocations[1]}: should_free={should_free} (pooled)") + + # Allocate again - should hit pool + result2 = pinned_manager.allocate(65536) + print(f"Re-allocated 65KB: reused={result2[2]}") + + pinned_stats = pinned_manager.stats() + print(f"\nPinned stats:") + print(f" Current used: {pinned_stats.current_used} bytes") + print(f" Pool hits: {pinned_stats.pool_hits}") + print(f" Pool misses: {pinned_stats.pool_misses}") + print(f" Hit rate: {pinned_stats.hit_rate():.1%}") + + # ========================================================================= + # 10. Kernel Cache (NEW) + # ========================================================================= + print_header("10. Kernel Cache (NEW)") + + cache_config = rust.CacheConfig( + max_entries=1024, + max_ptx_size=256 * 1024 * 1024, # 256 MB + enable_eviction=True, + ttl_seconds=0.0 # No TTL + ) + kernel_cache = rust.KernelCache(cache_config) + print(f"Created kernel cache: {cache_config}") + + # Compile options + compile_opts = rust.CompileOptions("sm_86").flag("-lineinfo").define("BLOCK_SIZE", "256") + print(f"\nCompile options: {compile_opts}") + + # Insert kernels + kernels = [ + ("__global__ void add_kernel(float* a, float* b, float* c) { ... }", "add_kernel"), + ("__global__ void mul_kernel(float* a, float* b, float* c) { ... }", "mul_kernel"), + ("__global__ void matmul_kernel(float* A, float* B, float* C, int M, int N, int K) { ... }", "matmul_kernel"), + ] + + for source, name in kernels: + ptx = f"// PTX for {name}\n.version 7.0\n.target sm_86\n..." + key = kernel_cache.insert(source, name, ptx, compile_opts) + print(f" Cached {name}: key={key}") + + # Test cache hits + print("\nTesting cache hits:") + for source, name in kernels: + cached = kernel_cache.get_by_name(name, compile_opts) + if cached: + print(f" {name}: HIT (accesses={cached.access_count})") + else: + print(f" {name}: MISS") + + # Simulate module loading + for i, (source, name) in enumerate(kernels): + key = rust.KernelCache.compute_key(source, name, compile_opts) + kernel_cache.set_handles(key, 0xAABB0000 + i, 0xCCDD0000 + i) + + cache_stats = kernel_cache.stats() + print(f"\nCache stats:") + print(f" Entries: {cache_stats.entries}") + print(f" Hits: {cache_stats.hits}") + print(f" Misses: {cache_stats.misses}") + print(f" Hit rate: {cache_stats.hit_rate():.1%}") + print(f" PTX size: {cache_stats.ptx_size} bytes") + print(f" Loaded kernels: {cache_stats.loaded_count}") + + # ========================================================================= + # 11. 
GPU Partitioning (NEW) + # ========================================================================= + print_header("11. GPU Partitioning (NEW)") + + partition_config = rust.PartitionConfig( + total_memory=8 * 1024 * 1024 * 1024, # 8 GB + allow_overcommit=False, + overcommit_ratio=1.0 + ) + partition_manager = rust.PartitionManager(partition_config) + print(f"Created partition manager: {partition_config}") + + # Create partitions + partitions = [ + ("inference", "Inference Workload", rust.PartitionLimits.with_memory(4 * 1024 * 1024 * 1024).compute(0.5).bandwidth(0.4)), + ("training", "Training Workload", rust.PartitionLimits.with_memory(3 * 1024 * 1024 * 1024).compute(0.4).bandwidth(0.5)), + ] + + for pid, name, limits in partitions: + partition_manager.create_partition(pid, name, limits) + print(f" Created partition '{pid}': memory={limits.memory_quota / 1024**3:.0f}GB, compute={limits.compute_share:.0%}") + + # Assign tasks + partition_manager.assign_task("inference-task-1", "inference") + partition_manager.assign_task("training-task-1", "training") + print(f"\nAssigned tasks to partitions") + + # Check partition for task + for task_id in ["inference-task-1", "training-task-1", "unknown-task"]: + partition = partition_manager.get_task_partition(task_id) + if partition: + print(f" {task_id} -> {partition.id} ({partition.name})") + else: + print(f" {task_id} -> (no partition)") + + partition_stats = partition_manager.stats() + print(f"\nPartition stats:") + print(f" Partitions: {partition_stats.partition_count}") + print(f" Memory allocated: {partition_stats.total_memory_allocated / 1024**3:.1f} GB") + print(f" Compute allocated: {partition_stats.total_compute_allocated:.0%}") + print(f" Available memory: {partition_stats.available_memory / 1024**3:.1f} GB") + print(f" Available compute: {partition_stats.available_compute:.0%}") + + # ========================================================================= + # 12. Tiled Matmul Benchmark + # ========================================================================= + print_header("12. 
Tiled Matmul Benchmark") + + print("\nMatrix Size | Kernel | Time (ms) | GFLOPS | vs NumPy") + print("-" * 60) + + sizes = [512, 1024, 2048, 4096] + results = [] + + for size in sizes: + M, N, K = size, size, size + + # Create test matrices + A_np = np.random.randn(M, K).astype(np.float32) + B_np = np.random.randn(K, N).astype(np.float32) + + # Warmup + A_gpu = native.from_numpy(A_np) + B_gpu = native.from_numpy(B_np) + _ = native.matmul(A_gpu, B_gpu) + + # Benchmark GPU + iterations = 5 if size >= 2048 else 10 + times = [] + for _ in range(iterations): + A_gpu = native.from_numpy(A_np) + B_gpu = native.from_numpy(B_np) + start = time.perf_counter() + C_gpu = native.matmul(A_gpu, B_gpu) + gpu_time = time.perf_counter() - start + times.append(gpu_time) + + avg_time = np.median(times) + gflops = 2 * M * N * K / avg_time / 1e9 + + kernel = "Tiled" if size >= 2048 else "L2-opt" + + # NumPy reference + start = time.perf_counter() + C_cpu = np.matmul(A_np, B_np) + cpu_time = time.perf_counter() - start + + speedup = cpu_time / avg_time + + # Verify + C_result = C_gpu.to_numpy() + rel_error = np.max(np.abs(C_result - C_cpu)) / np.max(np.abs(C_cpu)) + + results.append({ + 'size': size, + 'kernel': kernel, + 'time_ms': avg_time * 1000, + 'gflops': gflops, + 'speedup': speedup, + 'error': rel_error + }) + + status = "OK" if rel_error < 1e-3 else f"ERR:{rel_error:.1e}" + print(f"{size:>5}x{size:<5} | {kernel:<9} | {avg_time*1000:>8.2f} | {gflops:>7.1f} | {speedup:>5.1f}x ({status})") + + print("-" * 60) + + peak = max(results, key=lambda x: x['gflops']) + print(f"\nPeak: {peak['gflops']:.1f} GFLOPS at {peak['size']}x{peak['size']} ({peak['kernel']})") + + # ========================================================================= + # Summary + # ========================================================================= + print_header("Summary - PyGPUkit v0.2 Complete Features") + + print(""" + === Core Infrastructure === + 1. Rust Memory Pool - LRU eviction, size-class free lists + 2. Rust Scheduler - Priority queue, memory reservation + 3. Rust Transfer Engine - Separate H2D/D2H streams, priority + 4. Rust Kernel Dispatch - Per-stream limits, lifecycle tracking + + === NEW v0.2 Features === + 5. Admission Control - Deterministic admission, quota enforcement + 6. QoS Policy Framework - Guaranteed/Burstable/BestEffort tiers + 7. Kernel Pacing Engine - Bandwidth-based throttling per stream + 8. Micro-Slicing - Kernel splitting, round-robin fairness + 9. Pinned Memory - Page-locked host memory with pooling + 10. Kernel Cache - PTX caching, LRU eviction, TTL + 11. GPU Partitioning - Resource isolation, multi-tenant + + === Compute === + 12. Tiled Matmul - Shared memory + double buffering + """) + + # Count tests + print(f"Total Rust tests: 106 passing") + print(f"Features demonstrated: 12") + + print("\n" + "=" * 70) + print(" PyGPUkit v0.2 Demo Complete!") + print("=" * 70) + + return 0 + +if __name__ == "__main__": + sys.exit(main()) diff --git a/rust/pygpukit-core/src/dispatch/cache.rs b/rust/pygpukit-core/src/dispatch/cache.rs new file mode 100644 index 0000000..2fb04cd --- /dev/null +++ b/rust/pygpukit-core/src/dispatch/cache.rs @@ -0,0 +1,641 @@ +//! Kernel Cache +//! +//! Caches compiled CUDA kernels to avoid repeated NVRTC compilation. +//! Kernels are identified by a hash of their source code and compile options. 
+ +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; +use std::collections::hash_map::DefaultHasher; + +/// Compile options that affect kernel output +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct CompileOptions { + /// Compute capability (e.g., "sm_75") + pub compute_capability: String, + /// Additional compiler flags + pub flags: Vec, + /// Define macros + pub defines: Vec<(String, String)>, + /// Include paths + pub include_paths: Vec, +} + +impl Default for CompileOptions { + fn default() -> Self { + Self { + compute_capability: "sm_75".into(), + flags: Vec::new(), + defines: Vec::new(), + include_paths: Vec::new(), + } + } +} + +impl CompileOptions { + /// Create with compute capability + pub fn with_compute(compute: &str) -> Self { + Self { + compute_capability: compute.into(), + ..Default::default() + } + } + + /// Add a flag + pub fn flag(mut self, flag: &str) -> Self { + self.flags.push(flag.into()); + self + } + + /// Add a define macro + pub fn define(mut self, name: &str, value: &str) -> Self { + self.defines.push((name.into(), value.into())); + self + } + + /// Add an include path + pub fn include(mut self, path: &str) -> Self { + self.include_paths.push(path.into()); + self + } +} + +/// Cached kernel entry +#[derive(Debug, Clone)] +pub struct CachedKernel { + /// Cache key (hash) + pub key: u64, + /// Kernel name + pub name: String, + /// PTX code + pub ptx: String, + /// CUmodule handle (set after loading) + pub module_handle: Option, + /// CUfunction handle (set after loading) + pub function_handle: Option, + /// Compile options used + pub options: CompileOptions, + /// Creation timestamp + pub created_at: f64, + /// Last access timestamp + pub last_access: f64, + /// Access count + pub access_count: usize, + /// Source code hash (for verification) + pub source_hash: u64, +} + +impl CachedKernel { + /// Create a new cached kernel + pub fn new(key: u64, name: String, ptx: String, options: CompileOptions, source_hash: u64) -> Self { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_secs_f64()) + .unwrap_or(0.0); + Self { + key, + name, + ptx, + module_handle: None, + function_handle: None, + options, + created_at: now, + last_access: now, + access_count: 1, + source_hash, + } + } + + /// Set module and function handles + pub fn set_handles(&mut self, module: u64, function: u64) { + self.module_handle = Some(module); + self.function_handle = Some(function); + } + + /// Touch to update access time + pub fn touch(&mut self) { + self.last_access = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_secs_f64()) + .unwrap_or(0.0); + self.access_count += 1; + } + + /// Check if loaded + pub fn is_loaded(&self) -> bool { + self.function_handle.is_some() + } +} + +/// Kernel cache configuration +#[derive(Debug, Clone)] +pub struct CacheConfig { + /// Maximum cache entries + pub max_entries: usize, + /// Maximum PTX size in bytes + pub max_ptx_size: usize, + /// Enable LRU eviction + pub enable_eviction: bool, + /// TTL in seconds (0 = infinite) + pub ttl_seconds: f64, +} + +impl Default for CacheConfig { + fn default() -> Self { + Self { + max_entries: 1024, + max_ptx_size: 256 * 1024 * 1024, // 256MB + enable_eviction: true, + ttl_seconds: 0.0, // No TTL by default + } + } +} + +impl CacheConfig { + /// Create with max entries + pub fn with_max_entries(max_entries: usize) -> Self { + Self { + max_entries, + ..Default::default() + } + } + + /// Set max PTX size + pub fn 
max_ptx_size(mut self, bytes: usize) -> Self { + self.max_ptx_size = bytes; + self + } + + /// Set TTL + pub fn ttl(mut self, seconds: f64) -> Self { + self.ttl_seconds = seconds; + self + } +} + +/// Kernel cache statistics +#[derive(Debug, Clone, Default)] +pub struct CacheStats { + /// Cache hits + pub hits: usize, + /// Cache misses + pub misses: usize, + /// Total entries + pub entries: usize, + /// Total PTX size in bytes + pub ptx_size: usize, + /// Evictions due to capacity + pub evictions: usize, + /// Evictions due to TTL + pub ttl_evictions: usize, + /// Loaded kernels (with function handles) + pub loaded_count: usize, +} + +impl CacheStats { + /// Calculate hit rate + pub fn hit_rate(&self) -> f64 { + let total = self.hits + self.misses; + if total > 0 { + self.hits as f64 / total as f64 + } else { + 0.0 + } + } +} + +/// Kernel cache +/// +/// Caches compiled CUDA kernels to avoid repeated NVRTC compilation. +#[derive(Debug)] +pub struct KernelCache { + config: CacheConfig, + /// Cached kernels by key + cache: HashMap, + /// Name to key mapping for lookups + name_to_key: HashMap>, + /// Statistics + hits: usize, + misses: usize, + evictions: usize, + ttl_evictions: usize, + total_ptx_size: usize, +} + +impl KernelCache { + /// Create a new kernel cache + pub fn new(config: CacheConfig) -> Self { + Self { + config, + cache: HashMap::new(), + name_to_key: HashMap::new(), + hits: 0, + misses: 0, + evictions: 0, + ttl_evictions: 0, + total_ptx_size: 0, + } + } + + /// Create with defaults + pub fn with_defaults() -> Self { + Self::new(CacheConfig::default()) + } + + /// Compute cache key from source and options + pub fn compute_key(source: &str, name: &str, options: &CompileOptions) -> u64 { + let mut hasher = DefaultHasher::new(); + source.hash(&mut hasher); + name.hash(&mut hasher); + options.hash(&mut hasher); + hasher.finish() + } + + /// Compute source hash only + pub fn hash_source(source: &str) -> u64 { + let mut hasher = DefaultHasher::new(); + source.hash(&mut hasher); + hasher.finish() + } + + /// Get cached kernel by key + pub fn get(&mut self, key: u64) -> Option<&CachedKernel> { + // Check TTL first + if self.config.ttl_seconds > 0.0 { + if let Some(entry) = self.cache.get(&key) { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_secs_f64()) + .unwrap_or(0.0); + if now - entry.created_at > self.config.ttl_seconds { + // TTL expired - remove + self.remove(key); + self.ttl_evictions += 1; + self.misses += 1; + return None; + } + } + } + + if let Some(entry) = self.cache.get_mut(&key) { + entry.touch(); + self.hits += 1; + Some(entry) + } else { + self.misses += 1; + None + } + } + + /// Get cached kernel by name and options + pub fn get_by_name(&mut self, name: &str, options: &CompileOptions) -> Option<&CachedKernel> { + // Find keys for this name + let keys = self.name_to_key.get(name)?; + + // Find matching options + for &key in keys { + if let Some(entry) = self.cache.get(&key) { + if &entry.options == options { + // Touch and return + self.hits += 1; + if let Some(entry) = self.cache.get_mut(&key) { + entry.touch(); + } + return self.cache.get(&key); + } + } + } + + self.misses += 1; + None + } + + /// Insert a compiled kernel + pub fn insert(&mut self, source: &str, name: &str, ptx: String, options: CompileOptions) -> u64 { + let key = Self::compute_key(source, name, &options); + let source_hash = Self::hash_source(source); + + // Check if already exists + if self.cache.contains_key(&key) { + if let Some(entry) = 
self.cache.get_mut(&key) { + entry.touch(); + } + return key; + } + + // Evict if necessary + self.evict_if_needed(ptx.len()); + + // Insert + let entry = CachedKernel::new(key, name.into(), ptx.clone(), options, source_hash); + self.total_ptx_size += ptx.len(); + self.cache.insert(key, entry); + + // Update name mapping + self.name_to_key + .entry(name.into()) + .or_insert_with(Vec::new) + .push(key); + + key + } + + /// Update handles for a cached kernel + pub fn set_handles(&mut self, key: u64, module: u64, function: u64) -> bool { + if let Some(entry) = self.cache.get_mut(&key) { + entry.set_handles(module, function); + true + } else { + false + } + } + + /// Remove a kernel from cache + pub fn remove(&mut self, key: u64) -> Option { + if let Some(entry) = self.cache.remove(&key) { + self.total_ptx_size = self.total_ptx_size.saturating_sub(entry.ptx.len()); + + // Remove from name mapping + if let Some(keys) = self.name_to_key.get_mut(&entry.name) { + keys.retain(|&k| k != key); + if keys.is_empty() { + self.name_to_key.remove(&entry.name); + } + } + + Some(entry) + } else { + None + } + } + + /// Evict entries if needed + fn evict_if_needed(&mut self, new_size: usize) { + if !self.config.enable_eviction { + return; + } + + // Evict by entry count + while self.cache.len() >= self.config.max_entries { + self.evict_lru(); + } + + // Evict by size + while self.total_ptx_size + new_size > self.config.max_ptx_size && !self.cache.is_empty() { + self.evict_lru(); + } + } + + /// Evict least recently used entry + fn evict_lru(&mut self) { + // Find LRU entry + let lru_key = self.cache + .iter() + .min_by(|a, b| a.1.last_access.partial_cmp(&b.1.last_access).unwrap()) + .map(|(&k, _)| k); + + if let Some(key) = lru_key { + self.remove(key); + self.evictions += 1; + } + } + + /// Clear expired entries (TTL) + pub fn clear_expired(&mut self) -> usize { + if self.config.ttl_seconds <= 0.0 { + return 0; + } + + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_secs_f64()) + .unwrap_or(0.0); + + let expired: Vec = self.cache + .iter() + .filter(|(_, v)| now - v.created_at > self.config.ttl_seconds) + .map(|(&k, _)| k) + .collect(); + + let count = expired.len(); + for key in expired { + self.remove(key); + self.ttl_evictions += 1; + } + + count + } + + /// Get statistics + pub fn stats(&self) -> CacheStats { + let loaded_count = self.cache.values().filter(|e| e.is_loaded()).count(); + CacheStats { + hits: self.hits, + misses: self.misses, + entries: self.cache.len(), + ptx_size: self.total_ptx_size, + evictions: self.evictions, + ttl_evictions: self.ttl_evictions, + loaded_count, + } + } + + /// Check if kernel is cached + pub fn contains(&self, key: u64) -> bool { + self.cache.contains_key(&key) + } + + /// Get all cached kernel names + pub fn kernel_names(&self) -> Vec<&str> { + self.name_to_key.keys().map(|s| s.as_str()).collect() + } + + /// Get number of entries + pub fn len(&self) -> usize { + self.cache.len() + } + + /// Check if empty + pub fn is_empty(&self) -> bool { + self.cache.is_empty() + } + + /// Clear all cache + pub fn clear(&mut self) { + self.cache.clear(); + self.name_to_key.clear(); + self.total_ptx_size = 0; + } + + /// Get config + pub fn config(&self) -> &CacheConfig { + &self.config + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_compile_options() { + let opts = CompileOptions::with_compute("sm_80") + .flag("-lineinfo") + .define("BLOCK_SIZE", "256"); + + assert_eq!(opts.compute_capability, "sm_80"); + 
assert_eq!(opts.flags.len(), 1); + assert_eq!(opts.defines.len(), 1); + } + + #[test] + fn test_compute_key() { + let source = "__global__ void foo() {}"; + let opts = CompileOptions::default(); + + let key1 = KernelCache::compute_key(source, "foo", &opts); + let key2 = KernelCache::compute_key(source, "foo", &opts); + assert_eq!(key1, key2); + + // Different name = different key + let key3 = KernelCache::compute_key(source, "bar", &opts); + assert_ne!(key1, key3); + + // Different options = different key + let opts2 = CompileOptions::with_compute("sm_80"); + let key4 = KernelCache::compute_key(source, "foo", &opts2); + assert_ne!(key1, key4); + } + + #[test] + fn test_cache_insert_get() { + let mut cache = KernelCache::with_defaults(); + + let source = "__global__ void test_kernel() {}"; + let ptx = "// PTX code here"; + let opts = CompileOptions::default(); + + let key = cache.insert(source, "test_kernel", ptx.into(), opts.clone()); + + // Get should hit + let entry = cache.get(key); + assert!(entry.is_some()); + assert_eq!(entry.unwrap().name, "test_kernel"); + + let stats = cache.stats(); + assert_eq!(stats.hits, 1); + assert_eq!(stats.entries, 1); + } + + #[test] + fn test_cache_miss() { + let mut cache = KernelCache::with_defaults(); + + let result = cache.get(12345); + assert!(result.is_none()); + + let stats = cache.stats(); + assert_eq!(stats.misses, 1); + } + + #[test] + fn test_get_by_name() { + let mut cache = KernelCache::with_defaults(); + + let source = "__global__ void my_kernel() {}"; + let opts = CompileOptions::default(); + + cache.insert(source, "my_kernel", "ptx".into(), opts.clone()); + + let entry = cache.get_by_name("my_kernel", &opts); + assert!(entry.is_some()); + + // Different options should miss + let opts2 = CompileOptions::with_compute("sm_80"); + let entry2 = cache.get_by_name("my_kernel", &opts2); + assert!(entry2.is_none()); + } + + #[test] + fn test_set_handles() { + let mut cache = KernelCache::with_defaults(); + + let key = cache.insert("source", "kernel", "ptx".into(), CompileOptions::default()); + + assert!(!cache.get(key).unwrap().is_loaded()); + + cache.set_handles(key, 0xABCD, 0x1234); + + let entry = cache.get(key).unwrap(); + assert!(entry.is_loaded()); + assert_eq!(entry.module_handle, Some(0xABCD)); + assert_eq!(entry.function_handle, Some(0x1234)); + } + + #[test] + fn test_eviction() { + let config = CacheConfig::with_max_entries(2); + let mut cache = KernelCache::new(config); + + cache.insert("src1", "k1", "ptx1".into(), CompileOptions::default()); + cache.insert("src2", "k2", "ptx2".into(), CompileOptions::default()); + + // Access k2 to make k1 the LRU + let key2 = KernelCache::compute_key("src2", "k2", &CompileOptions::default()); + cache.get(key2); + + // Insert third - should evict k1 + cache.insert("src3", "k3", "ptx3".into(), CompileOptions::default()); + + assert_eq!(cache.len(), 2); + assert!(!cache.contains(KernelCache::compute_key("src1", "k1", &CompileOptions::default()))); + + let stats = cache.stats(); + assert_eq!(stats.evictions, 1); + } + + #[test] + fn test_remove() { + let mut cache = KernelCache::with_defaults(); + + let key = cache.insert("source", "kernel", "ptx".into(), CompileOptions::default()); + assert_eq!(cache.len(), 1); + + let removed = cache.remove(key); + assert!(removed.is_some()); + assert_eq!(cache.len(), 0); + assert!(cache.kernel_names().is_empty()); + } + + #[test] + fn test_clear() { + let mut cache = KernelCache::with_defaults(); + + cache.insert("src1", "k1", "ptx1".into(), 
CompileOptions::default()); + cache.insert("src2", "k2", "ptx2".into(), CompileOptions::default()); + + assert_eq!(cache.len(), 2); + + cache.clear(); + assert!(cache.is_empty()); + assert_eq!(cache.stats().ptx_size, 0); + } + + #[test] + fn test_hit_rate() { + let mut cache = KernelCache::with_defaults(); + + let key = cache.insert("source", "kernel", "ptx".into(), CompileOptions::default()); + + // 2 hits + cache.get(key); + cache.get(key); + + // 1 miss + cache.get(99999); + + let stats = cache.stats(); + assert_eq!(stats.hits, 2); + assert_eq!(stats.misses, 1); + assert!((stats.hit_rate() - 0.666).abs() < 0.01); + } +} diff --git a/rust/pygpukit-core/src/dispatch/mod.rs b/rust/pygpukit-core/src/dispatch/mod.rs index f89171f..28231af 100644 --- a/rust/pygpukit-core/src/dispatch/mod.rs +++ b/rust/pygpukit-core/src/dispatch/mod.rs @@ -4,10 +4,25 @@ //! - Per-task stream assignment //! - Integration with the scheduler tick loop //! - Kernel execution tracking +//! - Bandwidth-based kernel pacing +//! - Micro-slicing for fairness and latency +//! - Kernel caching for compiled PTX //! //! Note: Actual CUDA Driver API calls (cuLaunchKernel) are handled by C++ backend. //! This module provides the Rust-side coordination logic. mod controller; +mod pacing; +mod slicing; +mod cache; pub use controller::{KernelDispatcher, KernelLaunchRequest, KernelState, DispatchStats, LaunchConfig}; +pub use pacing::{ + KernelPacingEngine, PacingConfig, PacingDecision, PacingStats, StreamPacingStats, +}; +pub use slicing::{ + SliceScheduler, SliceConfig, SlicedKernel, KernelSlice, SliceInfo, SliceStats, +}; +pub use cache::{ + KernelCache, CacheConfig, CachedKernel, CompileOptions, CacheStats, +}; diff --git a/rust/pygpukit-core/src/dispatch/pacing.rs b/rust/pygpukit-core/src/dispatch/pacing.rs new file mode 100644 index 0000000..7013eb3 --- /dev/null +++ b/rust/pygpukit-core/src/dispatch/pacing.rs @@ -0,0 +1,410 @@ +//! Kernel Pacing Engine +//! +//! Throttles kernel launches based on allocated bandwidth. +//! Implements time-based pacing to prevent GPU saturation. 
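+//!
+//! A minimal usage sketch (illustrative only, not a doctest; it only uses the
+//! API defined in this module):
+//!
+//! ```ignore
+//! let mut engine = KernelPacingEngine::new(PacingConfig::with_bandwidth(1.0));
+//! engine.allocate_stream(0, 0.6); // reserve 60% of bandwidth for stream 0
+//! match engine.should_launch(0) {
+//!     PacingDecision::Launch => engine.record_launch(0),
+//!     PacingDecision::Wait { delay_ms } => { /* sleep ~delay_ms, then retry */ }
+//!     PacingDecision::Throttle { .. } => engine.record_throttle(0),
+//! }
+//! ```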
+ +use std::time::{Instant, Duration}; +use std::collections::HashMap; + +/// Pacing configuration +#[derive(Debug, Clone)] +pub struct PacingConfig { + /// Total bandwidth available (0.0 - 1.0) + pub total_bandwidth: f64, + /// Pacing window duration in milliseconds + pub window_ms: f64, + /// Minimum interval between launches in milliseconds + pub min_interval_ms: f64, + /// Enable adaptive pacing based on actual utilization + pub adaptive: bool, +} + +impl Default for PacingConfig { + fn default() -> Self { + Self { + total_bandwidth: 1.0, + window_ms: 100.0, + min_interval_ms: 0.1, + adaptive: true, + } + } +} + +impl PacingConfig { + /// Create with total bandwidth + pub fn with_bandwidth(bandwidth: f64) -> Self { + Self { + total_bandwidth: bandwidth, + ..Default::default() + } + } + + /// Set window duration + pub fn window(mut self, window_ms: f64) -> Self { + self.window_ms = window_ms; + self + } + + /// Set minimum interval + pub fn min_interval(mut self, min_ms: f64) -> Self { + self.min_interval_ms = min_ms; + self + } + + /// Enable/disable adaptive pacing + pub fn adaptive(mut self, enable: bool) -> Self { + self.adaptive = enable; + self + } +} + +/// Per-stream pacing state +#[derive(Debug)] +struct StreamPacing { + /// Last launch timestamp + last_launch: Option, + /// Allocated bandwidth for this stream + bandwidth: f64, + /// Launches in current window + launches_in_window: usize, + /// Window start time + window_start: Instant, + /// Total launches + total_launches: usize, + /// Total throttled requests + throttled_count: usize, +} + +impl StreamPacing { + fn new(bandwidth: f64) -> Self { + Self { + last_launch: None, + bandwidth, + launches_in_window: 0, + window_start: Instant::now(), + total_launches: 0, + throttled_count: 0, + } + } +} + +/// Pacing decision result +#[derive(Debug, Clone, PartialEq)] +pub enum PacingDecision { + /// Launch immediately + Launch, + /// Wait for the specified duration before launching + Wait { delay_ms: f64 }, + /// Throttled - exceed bandwidth allocation + Throttle { reason: String }, +} + +impl PacingDecision { + /// Check if immediate launch is allowed + pub fn can_launch(&self) -> bool { + matches!(self, Self::Launch) + } + + /// Check if throttled + pub fn is_throttled(&self) -> bool { + matches!(self, Self::Throttle { .. }) + } + + /// Get wait time in milliseconds (0.0 if can launch immediately) + pub fn wait_ms(&self) -> f64 { + match self { + Self::Launch => 0.0, + Self::Wait { delay_ms } => *delay_ms, + Self::Throttle { .. 
} => f64::INFINITY, + } + } +} + +/// Kernel pacing engine +#[derive(Debug)] +pub struct KernelPacingEngine { + config: PacingConfig, + /// Per-stream pacing state + streams: HashMap, + /// Global bandwidth usage tracking + used_bandwidth: f64, + /// Statistics + total_launches: usize, + total_throttled: usize, + total_waited: usize, +} + +impl KernelPacingEngine { + /// Create a new pacing engine + pub fn new(config: PacingConfig) -> Self { + Self { + config, + streams: HashMap::new(), + used_bandwidth: 0.0, + total_launches: 0, + total_throttled: 0, + total_waited: 0, + } + } + + /// Create with default config + pub fn with_defaults() -> Self { + Self::new(PacingConfig::default()) + } + + /// Allocate bandwidth for a stream + pub fn allocate_stream(&mut self, stream_id: u64, bandwidth: f64) -> bool { + let available = self.config.total_bandwidth - self.used_bandwidth; + if bandwidth > available { + return false; + } + + self.used_bandwidth += bandwidth; + self.streams.insert(stream_id, StreamPacing::new(bandwidth)); + true + } + + /// Release bandwidth for a stream + pub fn release_stream(&mut self, stream_id: u64) { + if let Some(pacing) = self.streams.remove(&stream_id) { + self.used_bandwidth = (self.used_bandwidth - pacing.bandwidth).max(0.0); + } + } + + /// Check if a kernel launch should proceed + /// + /// Returns a pacing decision indicating whether to launch, + /// wait, or throttle. + pub fn should_launch(&self, stream_id: u64) -> PacingDecision { + let stream = match self.streams.get(&stream_id) { + Some(s) => s, + None => return PacingDecision::Launch, // Unknown stream, allow + }; + + let _now = Instant::now(); + + // Check minimum interval + if let Some(last) = stream.last_launch { + let elapsed_ms = last.elapsed().as_secs_f64() * 1000.0; + if elapsed_ms < self.config.min_interval_ms { + let delay = self.config.min_interval_ms - elapsed_ms; + return PacingDecision::Wait { delay_ms: delay }; + } + } + + // Calculate pacing interval based on bandwidth allocation + // Higher bandwidth = shorter interval = more launches allowed + let pacing_interval_ms = if stream.bandwidth > 0.0 { + self.config.window_ms * (1.0 - stream.bandwidth) + } else { + self.config.window_ms // No bandwidth = must wait full window + }; + + // Check if we're within pacing interval + if let Some(last) = stream.last_launch { + let elapsed_ms = last.elapsed().as_secs_f64() * 1000.0; + if elapsed_ms < pacing_interval_ms { + let delay = pacing_interval_ms - elapsed_ms; + return PacingDecision::Wait { delay_ms: delay }; + } + } + + // Check window-based throttling + let window_elapsed = stream.window_start.elapsed(); + if window_elapsed < Duration::from_secs_f64(self.config.window_ms / 1000.0) { + // Still in window - check launch count + let max_launches_per_window = (stream.bandwidth * 100.0).ceil() as usize; + if stream.launches_in_window >= max_launches_per_window.max(1) { + return PacingDecision::Throttle { + reason: format!( + "Exceeded {} launches in {:.1}ms window", + max_launches_per_window, self.config.window_ms + ), + }; + } + } + + PacingDecision::Launch + } + + /// Record a kernel launch + pub fn record_launch(&mut self, stream_id: u64) { + let now = Instant::now(); + + if let Some(stream) = self.streams.get_mut(&stream_id) { + // Reset window if expired + let window_duration = Duration::from_secs_f64(self.config.window_ms / 1000.0); + if stream.window_start.elapsed() >= window_duration { + stream.window_start = now; + stream.launches_in_window = 0; + } + + stream.last_launch = Some(now); + 
stream.launches_in_window += 1; + stream.total_launches += 1; + } + + self.total_launches += 1; + } + + /// Record a throttled request + pub fn record_throttle(&mut self, stream_id: u64) { + if let Some(stream) = self.streams.get_mut(&stream_id) { + stream.throttled_count += 1; + } + self.total_throttled += 1; + } + + /// Record a waited request + pub fn record_wait(&mut self) { + self.total_waited += 1; + } + + /// Get stream statistics + pub fn stream_stats(&self, stream_id: u64) -> Option { + self.streams.get(&stream_id).map(|s| StreamPacingStats { + stream_id, + bandwidth: s.bandwidth, + launches_in_window: s.launches_in_window, + total_launches: s.total_launches, + throttled_count: s.throttled_count, + }) + } + + /// Get global statistics + pub fn stats(&self) -> PacingStats { + PacingStats { + stream_count: self.streams.len(), + used_bandwidth: self.used_bandwidth, + available_bandwidth: self.config.total_bandwidth - self.used_bandwidth, + total_launches: self.total_launches, + total_throttled: self.total_throttled, + total_waited: self.total_waited, + } + } + + /// Get configuration + pub fn config(&self) -> &PacingConfig { + &self.config + } + + /// Reset all pacing state + pub fn reset(&mut self) { + self.streams.clear(); + self.used_bandwidth = 0.0; + self.total_launches = 0; + self.total_throttled = 0; + self.total_waited = 0; + } +} + +/// Per-stream pacing statistics +#[derive(Debug, Clone, Default)] +pub struct StreamPacingStats { + /// Stream ID + pub stream_id: u64, + /// Allocated bandwidth + pub bandwidth: f64, + /// Launches in current window + pub launches_in_window: usize, + /// Total launches + pub total_launches: usize, + /// Throttled count + pub throttled_count: usize, +} + +/// Global pacing statistics +#[derive(Debug, Clone, Default)] +pub struct PacingStats { + /// Number of active streams + pub stream_count: usize, + /// Used bandwidth + pub used_bandwidth: f64, + /// Available bandwidth + pub available_bandwidth: f64, + /// Total launches + pub total_launches: usize, + /// Total throttled requests + pub total_throttled: usize, + /// Total waited requests + pub total_waited: usize, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_pacing_engine_creation() { + let engine = KernelPacingEngine::with_defaults(); + assert_eq!(engine.stats().stream_count, 0); + assert!((engine.config().total_bandwidth - 1.0).abs() < 0.001); + } + + #[test] + fn test_stream_allocation() { + let mut engine = KernelPacingEngine::with_defaults(); + + assert!(engine.allocate_stream(0, 0.5)); + assert!(engine.allocate_stream(1, 0.3)); + assert!(!engine.allocate_stream(2, 0.3)); // Would exceed 1.0 + + let stats = engine.stats(); + assert_eq!(stats.stream_count, 2); + assert!((stats.used_bandwidth - 0.8).abs() < 0.001); + } + + #[test] + fn test_launch_decision() { + let config = PacingConfig::default() + .window(10.0) + .min_interval(0.0); + let mut engine = KernelPacingEngine::new(config); + engine.allocate_stream(0, 0.5); + + // First launch should succeed + let decision = engine.should_launch(0); + assert!(decision.can_launch()); + } + + #[test] + fn test_pacing_interval() { + let config = PacingConfig::default() + .window(100.0) + .min_interval(5.0); + let mut engine = KernelPacingEngine::new(config); + engine.allocate_stream(0, 0.5); + + // First launch + engine.record_launch(0); + + // Second launch should wait + let decision = engine.should_launch(0); + match decision { + PacingDecision::Wait { delay_ms } => { + assert!(delay_ms > 0.0); + } + _ => 
panic!("Expected Wait decision"), + } + } + + #[test] + fn test_stream_release() { + let mut engine = KernelPacingEngine::with_defaults(); + engine.allocate_stream(0, 0.5); + engine.release_stream(0); + + let stats = engine.stats(); + assert_eq!(stats.stream_count, 0); + assert!((stats.used_bandwidth).abs() < 0.001); + } + + #[test] + fn test_unknown_stream() { + let engine = KernelPacingEngine::with_defaults(); + + // Unknown stream should be allowed + let decision = engine.should_launch(999); + assert!(decision.can_launch()); + } +} diff --git a/rust/pygpukit-core/src/dispatch/slicing.rs b/rust/pygpukit-core/src/dispatch/slicing.rs new file mode 100644 index 0000000..9183f59 --- /dev/null +++ b/rust/pygpukit-core/src/dispatch/slicing.rs @@ -0,0 +1,506 @@ +//! Micro-Slicing Framework +//! +//! Splits GPU kernels into small runnable slices to improve +//! latency and fairness under QoS constraints. + +use std::collections::VecDeque; + +/// Slice configuration +#[derive(Debug, Clone)] +pub struct SliceConfig { + /// Maximum work items per slice + pub max_items_per_slice: usize, + /// Maximum duration per slice in milliseconds + pub max_duration_ms: f64, + /// Minimum number of slices to create + pub min_slices: usize, + /// Maximum number of slices to create + pub max_slices: usize, + /// Enable adaptive slice sizing + pub adaptive: bool, +} + +impl Default for SliceConfig { + fn default() -> Self { + Self { + max_items_per_slice: 65536, + max_duration_ms: 1.0, + min_slices: 1, + max_slices: 256, + adaptive: true, + } + } +} + +impl SliceConfig { + /// Create with max items per slice + pub fn with_max_items(max_items: usize) -> Self { + Self { + max_items_per_slice: max_items, + ..Default::default() + } + } + + /// Set max duration + pub fn max_duration(mut self, ms: f64) -> Self { + self.max_duration_ms = ms; + self + } + + /// Set min slices + pub fn min_slices(mut self, n: usize) -> Self { + self.min_slices = n; + self + } + + /// Set max slices + pub fn max_slices(mut self, n: usize) -> Self { + self.max_slices = n; + self + } +} + +/// A single kernel slice +#[derive(Debug, Clone)] +pub struct KernelSlice { + /// Slice ID within the kernel + pub id: usize, + /// Starting offset in work items + pub offset: usize, + /// Number of work items in this slice + pub count: usize, + /// Grid dimensions for this slice + pub grid: (u32, u32, u32), + /// Whether this slice has been executed + pub executed: bool, + /// Execution time in milliseconds (if executed) + pub exec_time_ms: Option, +} + +impl KernelSlice { + /// Create a new slice + pub fn new(id: usize, offset: usize, count: usize, grid: (u32, u32, u32)) -> Self { + Self { + id, + offset, + count, + grid, + executed: false, + exec_time_ms: None, + } + } + + /// Mark as executed with timing + pub fn complete(&mut self, exec_time_ms: f64) { + self.executed = true; + self.exec_time_ms = Some(exec_time_ms); + } +} + +/// Sliced kernel representation +#[derive(Debug)] +pub struct SlicedKernel { + /// Kernel handle + pub kernel_handle: u64, + /// Block dimensions + pub block: (u32, u32, u32), + /// Shared memory per block + pub shared_mem: u32, + /// Total work items + pub total_items: usize, + /// Slices + pub slices: Vec, + /// Current slice index + pub current_slice: usize, + /// Associated task ID + pub task_id: Option, + /// Priority + pub priority: i32, +} + +impl SlicedKernel { + /// Create a sliced kernel + pub fn new( + kernel_handle: u64, + block: (u32, u32, u32), + shared_mem: u32, + total_items: usize, + slices: Vec, + ) -> Self { + 
Self { + kernel_handle, + block, + shared_mem, + total_items, + slices, + current_slice: 0, + task_id: None, + priority: 0, + } + } + + /// Set task ID + pub fn with_task(mut self, task_id: String) -> Self { + self.task_id = Some(task_id); + self + } + + /// Set priority + pub fn with_priority(mut self, priority: i32) -> Self { + self.priority = priority; + self + } + + /// Get next slice to execute + pub fn next_slice(&mut self) -> Option<&KernelSlice> { + if self.current_slice < self.slices.len() { + let slice = &self.slices[self.current_slice]; + Some(slice) + } else { + None + } + } + + /// Mark current slice as completed + pub fn complete_slice(&mut self, exec_time_ms: f64) { + if self.current_slice < self.slices.len() { + self.slices[self.current_slice].complete(exec_time_ms); + self.current_slice += 1; + } + } + + /// Check if all slices are executed + pub fn is_complete(&self) -> bool { + self.current_slice >= self.slices.len() + } + + /// Get completion progress (0.0 - 1.0) + pub fn progress(&self) -> f64 { + if self.slices.is_empty() { + 1.0 + } else { + self.current_slice as f64 / self.slices.len() as f64 + } + } + + /// Get total execution time so far + pub fn total_exec_time_ms(&self) -> f64 { + self.slices + .iter() + .filter_map(|s| s.exec_time_ms) + .sum() + } + + /// Get number of remaining slices + pub fn remaining_slices(&self) -> usize { + self.slices.len().saturating_sub(self.current_slice) + } +} + +/// Slice scheduler for interleaving slices across tasks +#[derive(Debug)] +pub struct SliceScheduler { + config: SliceConfig, + /// Queue of sliced kernels + queue: VecDeque, + /// Statistics + total_slices: usize, + completed_slices: usize, + total_kernels: usize, + completed_kernels: usize, +} + +impl SliceScheduler { + /// Create a new slice scheduler + pub fn new(config: SliceConfig) -> Self { + Self { + config, + queue: VecDeque::new(), + total_slices: 0, + completed_slices: 0, + total_kernels: 0, + completed_kernels: 0, + } + } + + /// Create with default config + pub fn with_defaults() -> Self { + Self::new(SliceConfig::default()) + } + + /// Slice a kernel and add to queue + pub fn submit(&mut self, kernel_handle: u64, total_items: usize, block: (u32, u32, u32), shared_mem: u32) -> usize { + let slices = self.create_slices(total_items, block.0); + let num_slices = slices.len(); + + let sliced = SlicedKernel::new(kernel_handle, block, shared_mem, total_items, slices); + self.queue.push_back(sliced); + self.total_slices += num_slices; + self.total_kernels += 1; + + num_slices + } + + /// Submit with task and priority + pub fn submit_for_task( + &mut self, + task_id: String, + kernel_handle: u64, + total_items: usize, + block: (u32, u32, u32), + shared_mem: u32, + priority: i32, + ) -> usize { + let slices = self.create_slices(total_items, block.0); + let num_slices = slices.len(); + + let sliced = SlicedKernel::new(kernel_handle, block, shared_mem, total_items, slices) + .with_task(task_id) + .with_priority(priority); + + // Insert sorted by priority (higher first) + let pos = self.queue.iter().position(|k| k.priority < priority); + match pos { + Some(i) => self.queue.insert(i, sliced), + None => self.queue.push_back(sliced), + } + + self.total_slices += num_slices; + self.total_kernels += 1; + num_slices + } + + /// Create slices for a kernel + fn create_slices(&self, total_items: usize, block_x: u32) -> Vec { + let items_per_block = block_x as usize; + + // Calculate number of slices + let num_slices = (total_items / self.config.max_items_per_slice).max(1); + let 
num_slices = num_slices.clamp(self.config.min_slices, self.config.max_slices); + + let items_per_slice = (total_items + num_slices - 1) / num_slices; + let _blocks_per_slice = (items_per_slice + items_per_block - 1) / items_per_block; + + let mut slices = Vec::new(); + let mut offset = 0; + + for id in 0..num_slices { + let remaining = total_items.saturating_sub(offset); + if remaining == 0 { + break; + } + + let count = remaining.min(items_per_slice); + let grid_x = ((count + items_per_block - 1) / items_per_block) as u32; + + slices.push(KernelSlice::new(id, offset, count, (grid_x, 1, 1))); + offset += count; + } + + slices + } + + /// Get next slice to execute (round-robin fair scheduling) + pub fn get_next_slice(&mut self) -> Option { + if self.queue.is_empty() { + return None; + } + + // Rotate to front kernel and get slice + let kernel = self.queue.front_mut()?; + let slice = kernel.next_slice()?.clone(); + + Some(SliceInfo { + kernel_handle: kernel.kernel_handle, + block: kernel.block, + shared_mem: kernel.shared_mem, + slice_id: slice.id, + offset: slice.offset, + count: slice.count, + grid: slice.grid, + task_id: kernel.task_id.clone(), + priority: kernel.priority, + }) + } + + /// Complete the current slice of the front kernel + pub fn complete_slice(&mut self, exec_time_ms: f64) { + if let Some(kernel) = self.queue.front_mut() { + kernel.complete_slice(exec_time_ms); + self.completed_slices += 1; + + // If kernel is complete, remove it and rotate + if kernel.is_complete() { + self.queue.pop_front(); + self.completed_kernels += 1; + } else { + // Rotate to back for fairness + if let Some(k) = self.queue.pop_front() { + self.queue.push_back(k); + } + } + } + } + + /// Get number of pending slices + pub fn pending_slices(&self) -> usize { + self.queue.iter().map(|k| k.remaining_slices()).sum() + } + + /// Get number of pending kernels + pub fn pending_kernels(&self) -> usize { + self.queue.len() + } + + /// Get statistics + pub fn stats(&self) -> SliceStats { + SliceStats { + total_slices: self.total_slices, + completed_slices: self.completed_slices, + pending_slices: self.pending_slices(), + total_kernels: self.total_kernels, + completed_kernels: self.completed_kernels, + pending_kernels: self.pending_kernels(), + } + } + + /// Clear all state + pub fn clear(&mut self) { + self.queue.clear(); + self.total_slices = 0; + self.completed_slices = 0; + self.total_kernels = 0; + self.completed_kernels = 0; + } + + /// Get config + pub fn config(&self) -> &SliceConfig { + &self.config + } +} + +/// Information about a slice to execute +#[derive(Debug, Clone)] +pub struct SliceInfo { + /// Kernel handle + pub kernel_handle: u64, + /// Block dimensions + pub block: (u32, u32, u32), + /// Shared memory + pub shared_mem: u32, + /// Slice ID + pub slice_id: usize, + /// Offset in work items + pub offset: usize, + /// Count of work items + pub count: usize, + /// Grid dimensions for this slice + pub grid: (u32, u32, u32), + /// Associated task ID + pub task_id: Option, + /// Priority + pub priority: i32, +} + +/// Slice scheduler statistics +#[derive(Debug, Clone, Default)] +pub struct SliceStats { + /// Total slices created + pub total_slices: usize, + /// Completed slices + pub completed_slices: usize, + /// Pending slices + pub pending_slices: usize, + /// Total kernels submitted + pub total_kernels: usize, + /// Completed kernels + pub completed_kernels: usize, + /// Pending kernels + pub pending_kernels: usize, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn 
test_slice_config() { + let config = SliceConfig::with_max_items(1024) + .max_duration(2.0) + .min_slices(2); + + assert_eq!(config.max_items_per_slice, 1024); + assert!((config.max_duration_ms - 2.0).abs() < 0.001); + assert_eq!(config.min_slices, 2); + } + + #[test] + fn test_slicing() { + let config = SliceConfig::with_max_items(1000).min_slices(1).max_slices(10); + let mut scheduler = SliceScheduler::new(config); + + // Submit kernel with 5000 items + let num_slices = scheduler.submit(0xDEADBEEF, 5000, (256, 1, 1), 0); + + assert!(num_slices >= 1); + assert!(num_slices <= 10); + + let stats = scheduler.stats(); + assert_eq!(stats.total_kernels, 1); + assert_eq!(stats.total_slices, num_slices); + } + + #[test] + fn test_slice_execution() { + let config = SliceConfig::with_max_items(100).min_slices(1); + let mut scheduler = SliceScheduler::new(config); + + scheduler.submit(0xDEADBEEF, 500, (256, 1, 1), 0); + + // Execute all slices + while let Some(slice_info) = scheduler.get_next_slice() { + assert!(slice_info.count > 0); + scheduler.complete_slice(0.1); + } + + let stats = scheduler.stats(); + assert_eq!(stats.pending_slices, 0); + assert_eq!(stats.completed_kernels, 1); + } + + #[test] + fn test_fair_scheduling() { + let config = SliceConfig::with_max_items(100).min_slices(2); + let mut scheduler = SliceScheduler::new(config); + + // Submit two kernels + scheduler.submit(0x1111, 200, (256, 1, 1), 0); + scheduler.submit(0x2222, 200, (256, 1, 1), 0); + + // First two slices should be from different kernels (round-robin) + let slice1 = scheduler.get_next_slice().unwrap(); + scheduler.complete_slice(0.1); + + let slice2 = scheduler.get_next_slice().unwrap(); + + // After completing first kernel's slice, it should rotate to second + // (depends on implementation details, but should be fair) + assert!(slice1.kernel_handle != slice2.kernel_handle || scheduler.pending_kernels() == 1); + } + + #[test] + fn test_priority_scheduling() { + let config = SliceConfig::with_max_items(1000); + let mut scheduler = SliceScheduler::new(config); + + // Submit low priority first + scheduler.submit_for_task("low".into(), 0x1111, 1000, (256, 1, 1), 0, 0); + // Submit high priority second + scheduler.submit_for_task("high".into(), 0x2222, 1000, (256, 1, 1), 0, 100); + + // High priority should come first + let slice = scheduler.get_next_slice().unwrap(); + assert_eq!(slice.priority, 100); + assert_eq!(slice.task_id, Some("high".into())); + } +} diff --git a/rust/pygpukit-core/src/lib.rs b/rust/pygpukit-core/src/lib.rs index c1f8286..169609c 100644 --- a/rust/pygpukit-core/src/lib.rs +++ b/rust/pygpukit-core/src/lib.rs @@ -5,6 +5,7 @@ //! - Task scheduler with bandwidth pacing //! - Async memory transfer engine with separate streams //! - Kernel dispatch controller with stream management +//! 
- Kernel pacing engine with bandwidth control pub mod memory; pub mod scheduler; @@ -12,6 +13,19 @@ pub mod transfer; pub mod dispatch; pub use memory::{MemoryBlock, MemoryPool, PoolStats, MemoryError}; -pub use scheduler::{TaskState, TaskPolicy, TaskMeta, Scheduler, SchedulerStats, TaskStats}; -pub use transfer::{TransferType, TransferOp, TransferState, AsyncTransferEngine, StreamType, TransferStats}; -pub use dispatch::{KernelDispatcher, KernelLaunchRequest, KernelState, DispatchStats, LaunchConfig}; +pub use scheduler::{ + TaskState, TaskPolicy, TaskMeta, Scheduler, SchedulerStats, TaskStats, + AdmissionController, AdmissionConfig, AdmissionDecision, AdmissionStats, RejectReason, + QosClass, QosPolicy, QosTaskMeta, QosEvaluation, QosPolicyEvaluator, QosStats, ResourceRequirements, + PartitionManager, PartitionConfig, Partition, PartitionLimits, PartitionUsage, PartitionStats, PartitionError, +}; +pub use transfer::{ + TransferType, TransferOp, TransferState, AsyncTransferEngine, StreamType, TransferStats, + PinnedMemoryManager, PinnedPoolConfig, PinnedBlock, PinnedStats, PinnedError, +}; +pub use dispatch::{ + KernelDispatcher, KernelLaunchRequest, KernelState, DispatchStats, LaunchConfig, + KernelPacingEngine, PacingConfig, PacingDecision, PacingStats, StreamPacingStats, + SliceScheduler, SliceConfig, SlicedKernel, KernelSlice, SliceInfo, SliceStats, + KernelCache, CacheConfig, CachedKernel, CompileOptions, CacheStats, +}; diff --git a/rust/pygpukit-core/src/scheduler/admission.rs b/rust/pygpukit-core/src/scheduler/admission.rs new file mode 100644 index 0000000..401cde6 --- /dev/null +++ b/rust/pygpukit-core/src/scheduler/admission.rs @@ -0,0 +1,538 @@ +//! Admission Control for GPU task scheduling +//! +//! Implements deterministic admission pipeline that enforces +//! memory quotas and bandwidth reservation before scheduling. 
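+//!
+//! A minimal illustrative sketch of the intended flow (sizes and task names are
+//! made up; `AdmissionController`, `AdmissionConfig`, and `AdmissionDecision` are
+//! defined below in this module, `TaskMeta` in the task module):
+//!
+//! ```ignore
+//! let config = AdmissionConfig::with_memory(512 * 1024 * 1024)
+//!     .max_pending(1024)
+//!     .best_effort(false);
+//! let mut ctrl = AdmissionController::new(config);
+//!
+//! let task = TaskMeta::with_memory("task-0".into(), "Demo".into(), 128 * 1024 * 1024);
+//!
+//! // `admit` = evaluate + reserve in one call.
+//! let decision = ctrl.admit(&task);
+//! if let AdmissionDecision::Admit { reserved_memory, reserved_bandwidth } = decision {
+//!     // ... run the task, then give the reservation back.
+//!     ctrl.release(reserved_memory, reserved_bandwidth);
+//! }
+//! ```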
+ +use crate::scheduler::task::TaskMeta; + +/// Reason for admission rejection +#[derive(Debug, Clone, PartialEq)] +pub enum RejectReason { + /// Insufficient memory quota + InsufficientMemory { + requested: usize, + available: usize, + }, + /// Bandwidth quota exceeded + BandwidthExceeded { + requested_bw: f64, + available_bw: f64, + }, + /// Too many pending tasks + QueueFull { + current: usize, + max: usize, + }, + /// Task dependencies not satisfiable + UnsatisfiableDependencies { + missing: Vec, + }, + /// Custom rejection reason + Custom(String), +} + +/// Result of admission control decision +#[derive(Debug, Clone, PartialEq)] +pub enum AdmissionDecision { + /// Task is admitted and can be scheduled + Admit { + /// Reserved memory for this task + reserved_memory: usize, + /// Reserved bandwidth fraction (0.0 - 1.0) + reserved_bandwidth: f64, + }, + /// Task is rejected + Reject { + reason: RejectReason, + }, + /// Task is queued for later (best-effort) + Queue { + /// Position in queue + position: usize, + /// Estimated wait time in milliseconds + estimated_wait_ms: f64, + }, +} + +impl AdmissionDecision { + /// Create an admit decision with memory reservation + pub fn admit(reserved_memory: usize) -> Self { + Self::Admit { + reserved_memory, + reserved_bandwidth: 0.0, + } + } + + /// Create an admit decision with memory and bandwidth reservation + pub fn admit_with_bandwidth(reserved_memory: usize, reserved_bandwidth: f64) -> Self { + Self::Admit { + reserved_memory, + reserved_bandwidth, + } + } + + /// Create a reject decision due to insufficient memory + pub fn reject_memory(requested: usize, available: usize) -> Self { + Self::Reject { + reason: RejectReason::InsufficientMemory { + requested, + available, + }, + } + } + + /// Create a reject decision due to bandwidth exceeded + pub fn reject_bandwidth(requested_bw: f64, available_bw: f64) -> Self { + Self::Reject { + reason: RejectReason::BandwidthExceeded { + requested_bw, + available_bw, + }, + } + } + + /// Create a reject decision due to queue full + pub fn reject_queue_full(current: usize, max: usize) -> Self { + Self::Reject { + reason: RejectReason::QueueFull { current, max }, + } + } + + /// Create a reject decision due to unsatisfiable dependencies + pub fn reject_dependencies(missing: Vec) -> Self { + Self::Reject { + reason: RejectReason::UnsatisfiableDependencies { missing }, + } + } + + /// Create a queue decision for best-effort tasks + pub fn queue(position: usize, estimated_wait_ms: f64) -> Self { + Self::Queue { + position, + estimated_wait_ms, + } + } + + /// Check if the decision is an admission + pub fn is_admitted(&self) -> bool { + matches!(self, Self::Admit { .. }) + } + + /// Check if the decision is a rejection + pub fn is_rejected(&self) -> bool { + matches!(self, Self::Reject { .. }) + } + + /// Check if the decision is queued + pub fn is_queued(&self) -> bool { + matches!(self, Self::Queue { .. 
}) + } + + /// Get the rejection reason if rejected + pub fn rejection_reason(&self) -> Option<&RejectReason> { + match self { + Self::Reject { reason } => Some(reason), + _ => None, + } + } +} + +/// Admission control configuration +#[derive(Debug, Clone)] +pub struct AdmissionConfig { + /// Total GPU memory available (bytes) + pub total_memory: usize, + /// Maximum memory that can be reserved (bytes) + pub max_reserved_memory: usize, + /// Maximum number of pending tasks + pub max_pending_tasks: usize, + /// Total bandwidth available (0.0 - 1.0) + pub total_bandwidth: f64, + /// Enable best-effort queueing for tasks that exceed quotas + pub enable_best_effort: bool, + /// Memory overcommit ratio (1.0 = no overcommit) + pub memory_overcommit_ratio: f64, +} + +impl Default for AdmissionConfig { + fn default() -> Self { + Self { + total_memory: usize::MAX, + max_reserved_memory: usize::MAX, + max_pending_tasks: 10000, + total_bandwidth: 1.0, + enable_best_effort: true, + memory_overcommit_ratio: 1.0, + } + } +} + +impl AdmissionConfig { + /// Create a new admission config with memory limit + pub fn with_memory(total_memory: usize) -> Self { + Self { + total_memory, + max_reserved_memory: total_memory, + ..Default::default() + } + } + + /// Set maximum pending tasks + pub fn max_pending(mut self, max: usize) -> Self { + self.max_pending_tasks = max; + self + } + + /// Set bandwidth limit + pub fn bandwidth(mut self, bw: f64) -> Self { + self.total_bandwidth = bw; + self + } + + /// Enable/disable best-effort queueing + pub fn best_effort(mut self, enable: bool) -> Self { + self.enable_best_effort = enable; + self + } + + /// Set memory overcommit ratio + pub fn overcommit(mut self, ratio: f64) -> Self { + self.memory_overcommit_ratio = ratio; + self + } +} + +/// Admission controller state +#[derive(Debug)] +pub struct AdmissionController { + config: AdmissionConfig, + /// Currently reserved memory + reserved_memory: usize, + /// Currently reserved bandwidth + reserved_bandwidth: f64, + /// Number of pending tasks + pending_count: usize, + /// Total admitted tasks + admitted_count: usize, + /// Total rejected tasks + rejected_count: usize, + /// Total queued tasks (best-effort) + queued_count: usize, +} + +impl AdmissionController { + /// Create a new admission controller + pub fn new(config: AdmissionConfig) -> Self { + Self { + config, + reserved_memory: 0, + reserved_bandwidth: 0.0, + pending_count: 0, + admitted_count: 0, + rejected_count: 0, + queued_count: 0, + } + } + + /// Create with default config + pub fn with_defaults() -> Self { + Self::new(AdmissionConfig::default()) + } + + /// Create with memory limit + pub fn with_memory(total_memory: usize) -> Self { + Self::new(AdmissionConfig::with_memory(total_memory)) + } + + /// Evaluate admission for a task + /// + /// This is a deterministic evaluation that does not modify state. + /// Call `reserve()` after admission to actually reserve resources. 
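+    ///
+    /// A minimal illustrative sketch (quota and sizes are made up):
+    ///
+    /// ```ignore
+    /// let mut ctrl = AdmissionController::with_memory(1000);
+    /// let task = TaskMeta::with_memory("t0".into(), "Demo".into(), 400);
+    ///
+    /// // Dry run: nothing is reserved yet.
+    /// let decision = ctrl.evaluate(&task);
+    /// assert!(decision.is_admitted());
+    /// assert_eq!(ctrl.reserved_memory(), 0);
+    ///
+    /// // Reserve only once the decision is acted on.
+    /// ctrl.reserve(&decision);
+    /// assert_eq!(ctrl.reserved_memory(), 400);
+    /// ```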
+ pub fn evaluate(&self, task: &TaskMeta) -> AdmissionDecision { + // Check queue capacity + if self.pending_count >= self.config.max_pending_tasks { + return AdmissionDecision::reject_queue_full( + self.pending_count, + self.config.max_pending_tasks, + ); + } + + // Calculate effective memory limit with overcommit + let effective_memory_limit = + (self.config.max_reserved_memory as f64 * self.config.memory_overcommit_ratio) as usize; + let available_memory = effective_memory_limit.saturating_sub(self.reserved_memory); + + // Check memory quota + if task.memory_estimate > available_memory { + if self.config.enable_best_effort { + // Queue for later + let estimated_wait = self.estimate_wait_time(task.memory_estimate); + return AdmissionDecision::queue(self.pending_count, estimated_wait); + } else { + return AdmissionDecision::reject_memory(task.memory_estimate, available_memory); + } + } + + // Calculate bandwidth requirement (based on memory estimate) + // Simplified model: bandwidth proportional to memory (adjusted for overcommit) + let bandwidth_estimate = if effective_memory_limit > 0 { + task.memory_estimate as f64 / effective_memory_limit as f64 + } else { + 0.0 + }; + + let available_bandwidth = self.config.total_bandwidth - self.reserved_bandwidth; + + // Check bandwidth quota + if bandwidth_estimate > available_bandwidth { + if self.config.enable_best_effort { + let estimated_wait = self.estimate_wait_time(task.memory_estimate); + return AdmissionDecision::queue(self.pending_count, estimated_wait); + } else { + return AdmissionDecision::reject_bandwidth(bandwidth_estimate, available_bandwidth); + } + } + + // Admit the task + AdmissionDecision::admit_with_bandwidth(task.memory_estimate, bandwidth_estimate) + } + + /// Reserve resources for an admitted task + /// + /// Call this after receiving an `Admit` decision. + pub fn reserve(&mut self, decision: &AdmissionDecision) -> bool { + match decision { + AdmissionDecision::Admit { + reserved_memory, + reserved_bandwidth, + } => { + self.reserved_memory += reserved_memory; + self.reserved_bandwidth += reserved_bandwidth; + self.pending_count += 1; + self.admitted_count += 1; + true + } + AdmissionDecision::Queue { .. } => { + self.pending_count += 1; + self.queued_count += 1; + true + } + AdmissionDecision::Reject { .. 
} => { + self.rejected_count += 1; + false + } + } + } + + /// Release resources when a task completes + pub fn release(&mut self, memory: usize, bandwidth: f64) { + self.reserved_memory = self.reserved_memory.saturating_sub(memory); + self.reserved_bandwidth = (self.reserved_bandwidth - bandwidth).max(0.0); + self.pending_count = self.pending_count.saturating_sub(1); + } + + /// Admit a task (evaluate + reserve in one call) + pub fn admit(&mut self, task: &TaskMeta) -> AdmissionDecision { + let decision = self.evaluate(task); + self.reserve(&decision); + decision + } + + /// Estimate wait time for a task requiring given memory + fn estimate_wait_time(&self, _memory_needed: usize) -> f64 { + // Simple heuristic: assume tasks complete at ~100MB/s throughput + let memory_throughput = 100.0 * 1024.0 * 1024.0; // 100 MB/s + let wait_seconds = self.reserved_memory as f64 / memory_throughput; + wait_seconds * 1000.0 // Convert to ms + } + + /// Get current reserved memory + pub fn reserved_memory(&self) -> usize { + self.reserved_memory + } + + /// Get current reserved bandwidth + pub fn reserved_bandwidth(&self) -> f64 { + self.reserved_bandwidth + } + + /// Get available memory + pub fn available_memory(&self) -> usize { + let effective = (self.config.max_reserved_memory as f64 + * self.config.memory_overcommit_ratio) as usize; + effective.saturating_sub(self.reserved_memory) + } + + /// Get available bandwidth + pub fn available_bandwidth(&self) -> f64 { + (self.config.total_bandwidth - self.reserved_bandwidth).max(0.0) + } + + /// Get admission statistics + pub fn stats(&self) -> AdmissionStats { + AdmissionStats { + admitted_count: self.admitted_count, + rejected_count: self.rejected_count, + queued_count: self.queued_count, + pending_count: self.pending_count, + reserved_memory: self.reserved_memory, + reserved_bandwidth: self.reserved_bandwidth, + available_memory: self.available_memory(), + available_bandwidth: self.available_bandwidth(), + } + } + + /// Reset the controller state + pub fn reset(&mut self) { + self.reserved_memory = 0; + self.reserved_bandwidth = 0.0; + self.pending_count = 0; + self.admitted_count = 0; + self.rejected_count = 0; + self.queued_count = 0; + } + + /// Get the config + pub fn config(&self) -> &AdmissionConfig { + &self.config + } +} + +/// Admission control statistics +#[derive(Debug, Clone, Default)] +pub struct AdmissionStats { + /// Total admitted tasks + pub admitted_count: usize, + /// Total rejected tasks + pub rejected_count: usize, + /// Total queued tasks (best-effort) + pub queued_count: usize, + /// Current pending tasks + pub pending_count: usize, + /// Currently reserved memory + pub reserved_memory: usize, + /// Currently reserved bandwidth + pub reserved_bandwidth: f64, + /// Available memory + pub available_memory: usize, + /// Available bandwidth + pub available_bandwidth: f64, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_admission_simple() { + let mut controller = AdmissionController::with_memory(1000); + + let task = TaskMeta::with_memory("task-1".into(), "Test".into(), 500); + let decision = controller.admit(&task); + + assert!(decision.is_admitted()); + assert_eq!(controller.reserved_memory(), 500); + } + + #[test] + fn test_admission_reject_memory() { + let config = AdmissionConfig::with_memory(1000).best_effort(false); + let mut controller = AdmissionController::new(config); + + let task = TaskMeta::with_memory("task-1".into(), "Test".into(), 1500); + let decision = controller.admit(&task); + + 
assert!(decision.is_rejected()); + match decision.rejection_reason() { + Some(RejectReason::InsufficientMemory { requested, available }) => { + assert_eq!(*requested, 1500); + assert_eq!(*available, 1000); + } + _ => panic!("Expected InsufficientMemory rejection"), + } + } + + #[test] + fn test_admission_queue_best_effort() { + let config = AdmissionConfig::with_memory(1000).best_effort(true); + let mut controller = AdmissionController::new(config); + + // Fill memory + let task1 = TaskMeta::with_memory("task-1".into(), "Test".into(), 800); + controller.admit(&task1); + + // This should be queued + let task2 = TaskMeta::with_memory("task-2".into(), "Test".into(), 500); + let decision = controller.admit(&task2); + + assert!(decision.is_queued()); + } + + #[test] + fn test_admission_queue_full() { + let config = AdmissionConfig::with_memory(10000).max_pending(2); + let mut controller = AdmissionController::new(config); + + // Submit 2 tasks + let task1 = TaskMeta::with_memory("task-1".into(), "Test".into(), 100); + let task2 = TaskMeta::with_memory("task-2".into(), "Test".into(), 100); + controller.admit(&task1); + controller.admit(&task2); + + // Third should be rejected + let task3 = TaskMeta::with_memory("task-3".into(), "Test".into(), 100); + let decision = controller.admit(&task3); + + assert!(decision.is_rejected()); + match decision.rejection_reason() { + Some(RejectReason::QueueFull { current, max }) => { + assert_eq!(*current, 2); + assert_eq!(*max, 2); + } + _ => panic!("Expected QueueFull rejection"), + } + } + + #[test] + fn test_admission_release() { + let mut controller = AdmissionController::with_memory(1000); + + let task = TaskMeta::with_memory("task-1".into(), "Test".into(), 500); + let decision = controller.admit(&task); + + assert!(decision.is_admitted()); + assert_eq!(controller.reserved_memory(), 500); + + // Release + controller.release(500, 0.0); + assert_eq!(controller.reserved_memory(), 0); + assert_eq!(controller.available_memory(), 1000); + } + + #[test] + fn test_admission_overcommit() { + let config = AdmissionConfig::with_memory(1000).overcommit(1.5); + let mut controller = AdmissionController::new(config); + + // Can admit up to 1500 bytes + let task = TaskMeta::with_memory("task-1".into(), "Test".into(), 1200); + let decision = controller.admit(&task); + + assert!(decision.is_admitted()); + } + + #[test] + fn test_admission_stats() { + let mut controller = AdmissionController::with_memory(1000); + + // Admit one + let task1 = TaskMeta::with_memory("task-1".into(), "Test".into(), 500); + controller.admit(&task1); + + let stats = controller.stats(); + assert_eq!(stats.admitted_count, 1); + assert_eq!(stats.pending_count, 1); + assert_eq!(stats.reserved_memory, 500); + assert_eq!(stats.available_memory, 500); + } +} diff --git a/rust/pygpukit-core/src/scheduler/core.rs b/rust/pygpukit-core/src/scheduler/core.rs index c544de9..24cf949 100644 --- a/rust/pygpukit-core/src/scheduler/core.rs +++ b/rust/pygpukit-core/src/scheduler/core.rs @@ -6,6 +6,7 @@ use std::collections::{HashMap, VecDeque}; use std::time::{SystemTime, UNIX_EPOCH}; use parking_lot::RwLock; use crate::scheduler::task::{TaskMeta, TaskState, TaskStats}; +use crate::scheduler::admission::{AdmissionController, AdmissionConfig, AdmissionDecision}; /// Scheduler statistics #[derive(Debug, Clone, Default)] @@ -46,6 +47,8 @@ struct SchedulerInner { total_wait_time: f64, total_exec_time: f64, completed_count: usize, + /// Admission controller + admission: AdmissionController, } /// Thread-safe task 
scheduler with bandwidth pacing. @@ -81,6 +84,11 @@ impl Scheduler { /// * `sched_tick_ms` - Scheduling tick interval in milliseconds /// * `window_ms` - Bandwidth pacing window in milliseconds pub fn new(total_memory: Option, sched_tick_ms: f64, window_ms: f64) -> Self { + let admission_config = match total_memory { + Some(mem) => AdmissionConfig::with_memory(mem), + None => AdmissionConfig::default(), + }; + Self { total_memory, sched_tick_ms, @@ -93,10 +101,54 @@ impl Scheduler { total_wait_time: 0.0, total_exec_time: 0.0, completed_count: 0, + admission: AdmissionController::new(admission_config), }), } } + /// Evaluate admission for a task without submitting it. + /// + /// This performs a dry-run admission check to determine if + /// a task would be admitted, queued, or rejected. + pub fn evaluate_admission(&self, task: &TaskMeta) -> AdmissionDecision { + let inner = self.inner.read(); + inner.admission.evaluate(task) + } + + /// Admit a task through the admission control pipeline. + /// + /// Returns an AdmissionDecision indicating whether the task + /// was admitted, queued for best-effort, or rejected. + /// + /// If admitted or queued, the task is automatically submitted. + pub fn admit(&self, task: TaskMeta) -> AdmissionDecision { + let mut inner = self.inner.write(); + let decision = inner.admission.admit(&task); + + match &decision { + AdmissionDecision::Admit { reserved_memory, .. } => { + // Task admitted - add to scheduler + let task_id = task.id.clone(); + inner.pending_queue.push_back(task_id.clone()); + inner.reserved_memory += reserved_memory; + inner.tasks.insert(task_id, task); + } + AdmissionDecision::Queue { .. } => { + // Task queued for best-effort - still add to scheduler + let task_id = task.id.clone(); + let memory = task.memory_estimate; + inner.pending_queue.push_back(task_id.clone()); + inner.reserved_memory += memory; + inner.tasks.insert(task_id, task); + } + AdmissionDecision::Reject { .. } => { + // Task rejected - do not add + } + } + + decision + } + /// Submit a task for scheduling. /// /// Memory is reserved immediately upon submission to ensure @@ -379,6 +431,12 @@ impl Scheduler { inner.total_wait_time = 0.0; inner.total_exec_time = 0.0; inner.completed_count = 0; + inner.admission.reset(); + } + + /// Get admission control statistics. + pub fn admission_stats(&self) -> crate::scheduler::admission::AdmissionStats { + self.inner.read().admission.stats() } /// Get scheduling tick interval. diff --git a/rust/pygpukit-core/src/scheduler/mod.rs b/rust/pygpukit-core/src/scheduler/mod.rs index 6be6941..38e72c3 100644 --- a/rust/pygpukit-core/src/scheduler/mod.rs +++ b/rust/pygpukit-core/src/scheduler/mod.rs @@ -4,9 +4,27 @@ //! - Priority-based task execution //! - Bandwidth pacing //! - Memory reservation tracking +//! - Admission control +//! - QoS policy framework +//! 
- GPU resource partitioning mod task; mod core; +mod admission; +mod qos; +mod partition; pub use task::{TaskState, TaskPolicy, TaskMeta, TaskStats}; pub use core::{Scheduler, SchedulerStats}; +pub use admission::{ + AdmissionController, AdmissionConfig, AdmissionDecision, + AdmissionStats, RejectReason, +}; +pub use qos::{ + QosClass, QosPolicy, QosTaskMeta, QosEvaluation, + QosPolicyEvaluator, QosStats, ResourceRequirements, +}; +pub use partition::{ + PartitionManager, PartitionConfig, Partition, PartitionLimits, + PartitionUsage, PartitionStats, PartitionError, +}; diff --git a/rust/pygpukit-core/src/scheduler/partition.rs b/rust/pygpukit-core/src/scheduler/partition.rs new file mode 100644 index 0000000..4bac629 --- /dev/null +++ b/rust/pygpukit-core/src/scheduler/partition.rs @@ -0,0 +1,782 @@ +//! GPU Partitioning +//! +//! Provides logical partitioning of GPU resources across tasks or tenants. +//! Supports partitioning of: +//! - Compute resources (SM units) +//! - Memory quota +//! - Bandwidth allocation +//! - Stream capacity + +use std::collections::HashMap; + +/// Partition resource limits +#[derive(Debug, Clone)] +pub struct PartitionLimits { + /// Memory quota in bytes (0 = unlimited) + pub memory_quota: usize, + /// Compute share (0.0 - 1.0, fraction of GPU) + pub compute_share: f64, + /// Bandwidth share (0.0 - 1.0) + pub bandwidth_share: f64, + /// Maximum concurrent streams + pub max_streams: usize, + /// Maximum pending kernels + pub max_pending_kernels: usize, + /// Maximum pending transfers + pub max_pending_transfers: usize, +} + +impl Default for PartitionLimits { + fn default() -> Self { + Self { + memory_quota: 0, // Unlimited + compute_share: 1.0, + bandwidth_share: 1.0, + max_streams: 16, + max_pending_kernels: 1024, + max_pending_transfers: 256, + } + } +} + +impl PartitionLimits { + /// Create with memory quota + pub fn with_memory(memory_quota: usize) -> Self { + Self { + memory_quota, + ..Default::default() + } + } + + /// Create with compute share + pub fn with_compute(compute_share: f64) -> Self { + Self { + compute_share: compute_share.clamp(0.0, 1.0), + ..Default::default() + } + } + + /// Set memory quota + pub fn memory(mut self, quota: usize) -> Self { + self.memory_quota = quota; + self + } + + /// Set compute share + pub fn compute(mut self, share: f64) -> Self { + self.compute_share = share.clamp(0.0, 1.0); + self + } + + /// Set bandwidth share + pub fn bandwidth(mut self, share: f64) -> Self { + self.bandwidth_share = share.clamp(0.0, 1.0); + self + } + + /// Set max streams + pub fn streams(mut self, max: usize) -> Self { + self.max_streams = max; + self + } +} + +/// Partition resource usage +#[derive(Debug, Clone, Default)] +pub struct PartitionUsage { + /// Current memory usage in bytes + pub memory_used: usize, + /// Active streams + pub active_streams: usize, + /// Pending kernels + pub pending_kernels: usize, + /// Pending transfers + pub pending_transfers: usize, + /// Total kernels executed + pub total_kernels: usize, + /// Total transfers completed + pub total_transfers: usize, + /// Compute time in milliseconds + pub compute_time_ms: f64, +} + +impl PartitionUsage { + /// Check if memory would exceed quota + pub fn would_exceed_memory(&self, limits: &PartitionLimits, size: usize) -> bool { + limits.memory_quota > 0 && self.memory_used + size > limits.memory_quota + } + + /// Check if stream limit reached + pub fn stream_limit_reached(&self, limits: &PartitionLimits) -> bool { + self.active_streams >= limits.max_streams + } + + /// 
Check if kernel limit reached + pub fn kernel_limit_reached(&self, limits: &PartitionLimits) -> bool { + self.pending_kernels >= limits.max_pending_kernels + } + + /// Check if transfer limit reached + pub fn transfer_limit_reached(&self, limits: &PartitionLimits) -> bool { + self.pending_transfers >= limits.max_pending_transfers + } +} + +/// A GPU partition +#[derive(Debug, Clone)] +pub struct Partition { + /// Partition ID + pub id: String, + /// Partition name + pub name: String, + /// Resource limits + pub limits: PartitionLimits, + /// Current usage + pub usage: PartitionUsage, + /// Associated task IDs + pub tasks: Vec, + /// Creation timestamp + pub created_at: f64, + /// Whether partition is enabled + pub enabled: bool, +} + +impl Partition { + /// Create a new partition + pub fn new(id: String, name: String, limits: PartitionLimits) -> Self { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_secs_f64()) + .unwrap_or(0.0); + Self { + id, + name, + limits, + usage: PartitionUsage::default(), + tasks: Vec::new(), + created_at: now, + enabled: true, + } + } + + /// Check if memory allocation is allowed + pub fn can_allocate_memory(&self, size: usize) -> bool { + !self.usage.would_exceed_memory(&self.limits, size) + } + + /// Check if stream creation is allowed + pub fn can_create_stream(&self) -> bool { + !self.usage.stream_limit_reached(&self.limits) + } + + /// Check if kernel submission is allowed + pub fn can_submit_kernel(&self) -> bool { + !self.usage.kernel_limit_reached(&self.limits) + } + + /// Check if transfer submission is allowed + pub fn can_submit_transfer(&self) -> bool { + !self.usage.transfer_limit_reached(&self.limits) + } + + /// Allocate memory + pub fn allocate_memory(&mut self, size: usize) -> bool { + if self.can_allocate_memory(size) { + self.usage.memory_used += size; + true + } else { + false + } + } + + /// Free memory + pub fn free_memory(&mut self, size: usize) { + self.usage.memory_used = self.usage.memory_used.saturating_sub(size); + } + + /// Register a stream + pub fn register_stream(&mut self) -> bool { + if self.can_create_stream() { + self.usage.active_streams += 1; + true + } else { + false + } + } + + /// Release a stream + pub fn release_stream(&mut self) { + self.usage.active_streams = self.usage.active_streams.saturating_sub(1); + } + + /// Submit a kernel + pub fn submit_kernel(&mut self) -> bool { + if self.can_submit_kernel() { + self.usage.pending_kernels += 1; + true + } else { + false + } + } + + /// Complete a kernel + pub fn complete_kernel(&mut self, exec_time_ms: f64) { + self.usage.pending_kernels = self.usage.pending_kernels.saturating_sub(1); + self.usage.total_kernels += 1; + self.usage.compute_time_ms += exec_time_ms; + } + + /// Submit a transfer + pub fn submit_transfer(&mut self) -> bool { + if self.can_submit_transfer() { + self.usage.pending_transfers += 1; + true + } else { + false + } + } + + /// Complete a transfer + pub fn complete_transfer(&mut self) { + self.usage.pending_transfers = self.usage.pending_transfers.saturating_sub(1); + self.usage.total_transfers += 1; + } + + /// Add a task to this partition + pub fn add_task(&mut self, task_id: String) { + if !self.tasks.contains(&task_id) { + self.tasks.push(task_id); + } + } + + /// Remove a task from this partition + pub fn remove_task(&mut self, task_id: &str) { + self.tasks.retain(|t| t != task_id); + } + + /// Check if task belongs to this partition + pub fn has_task(&self, task_id: &str) -> bool { + 
self.tasks.iter().any(|t| t == task_id) + } + + /// Get memory utilization (0.0 - 1.0) + pub fn memory_utilization(&self) -> f64 { + if self.limits.memory_quota > 0 { + self.usage.memory_used as f64 / self.limits.memory_quota as f64 + } else { + 0.0 + } + } +} + +/// Partition manager configuration +#[derive(Debug, Clone)] +pub struct PartitionConfig { + /// Total GPU memory available for partitioning + pub total_memory: usize, + /// Total compute capacity (normalized to 1.0) + pub total_compute: f64, + /// Total bandwidth capacity (normalized to 1.0) + pub total_bandwidth: f64, + /// Allow overcommit + pub allow_overcommit: bool, + /// Overcommit ratio (1.0 = no overcommit) + pub overcommit_ratio: f64, +} + +impl Default for PartitionConfig { + fn default() -> Self { + Self { + total_memory: 8 * 1024 * 1024 * 1024, // 8GB default + total_compute: 1.0, + total_bandwidth: 1.0, + allow_overcommit: false, + overcommit_ratio: 1.0, + } + } +} + +impl PartitionConfig { + /// Create with total memory + pub fn with_memory(total_memory: usize) -> Self { + Self { + total_memory, + ..Default::default() + } + } + + /// Enable overcommit + pub fn overcommit(mut self, ratio: f64) -> Self { + self.allow_overcommit = true; + self.overcommit_ratio = ratio; + self + } +} + +/// Partition manager statistics +#[derive(Debug, Clone, Default)] +pub struct PartitionStats { + /// Total partitions + pub partition_count: usize, + /// Active partitions + pub active_partitions: usize, + /// Total memory allocated across partitions + pub total_memory_allocated: usize, + /// Total compute share allocated + pub total_compute_allocated: f64, + /// Total bandwidth allocated + pub total_bandwidth_allocated: f64, + /// Available memory + pub available_memory: usize, + /// Available compute + pub available_compute: f64, + /// Available bandwidth + pub available_bandwidth: f64, +} + +/// Partition manager +/// +/// Manages GPU resource partitions for multi-tenant or multi-task isolation. 
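+///
+/// A minimal illustrative sketch (partition names and sizes are made up):
+///
+/// ```ignore
+/// let mut mgr = PartitionManager::with_memory(8 * 1024 * 1024 * 1024);
+/// let limits = PartitionLimits::with_memory(2 * 1024 * 1024 * 1024)
+///     .compute(0.5)
+///     .bandwidth(0.5);
+/// mgr.create_partition("tenant-a", "Tenant A", limits).unwrap();
+/// mgr.assign_task("job-1", "tenant-a").unwrap();
+///
+/// // Per-partition accounting is done on the `Partition` itself.
+/// let p = mgr.get_mut("tenant-a").unwrap();
+/// assert!(p.allocate_memory(512 * 1024 * 1024));
+/// p.free_memory(512 * 1024 * 1024);
+/// ```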
+#[derive(Debug)]
+pub struct PartitionManager {
+    config: PartitionConfig,
+    /// Partitions by ID
+    partitions: HashMap<String, Partition>,
+    /// Task to partition mapping
+    task_partition: HashMap<String, String>,
+    /// Default partition ID
+    default_partition: Option<String>,
+}
+
+impl PartitionManager {
+    /// Create a new partition manager
+    pub fn new(config: PartitionConfig) -> Self {
+        Self {
+            config,
+            partitions: HashMap::new(),
+            task_partition: HashMap::new(),
+            default_partition: None,
+        }
+    }
+
+    /// Create with defaults
+    pub fn with_defaults() -> Self {
+        Self::new(PartitionConfig::default())
+    }
+
+    /// Create with total memory
+    pub fn with_memory(total_memory: usize) -> Self {
+        Self::new(PartitionConfig::with_memory(total_memory))
+    }
+
+    /// Create a new partition
+    pub fn create_partition(&mut self, id: &str, name: &str, limits: PartitionLimits) -> Result<(), PartitionError> {
+        if self.partitions.contains_key(id) {
+            return Err(PartitionError::AlreadyExists { id: id.into() });
+        }
+
+        // Validate limits
+        self.validate_limits(&limits)?;
+
+        let partition = Partition::new(id.into(), name.into(), limits);
+        self.partitions.insert(id.into(), partition);
+
+        // Set as default if first
+        if self.default_partition.is_none() {
+            self.default_partition = Some(id.into());
+        }
+
+        Ok(())
+    }
+
+    /// Validate partition limits against available resources
+    fn validate_limits(&self, limits: &PartitionLimits) -> Result<(), PartitionError> {
+        let stats = self.stats();
+        // Check memory
+        if limits.memory_quota > 0 && limits.memory_quota > stats.available_memory && !self.config.allow_overcommit {
+            return Err(PartitionError::InsufficientResources {
+                resource: "memory".into(),
+                requested: limits.memory_quota,
+                available: stats.available_memory,
+            });
+        }
+
+        // Check compute
+        if limits.compute_share > stats.available_compute && !self.config.allow_overcommit {
+            return Err(PartitionError::InsufficientResources {
+                resource: "compute".into(),
+                requested: (limits.compute_share * 100.0) as usize,
+                available: (stats.available_compute * 100.0) as usize,
+            });
+        }
+
+        // Check bandwidth
+        if limits.bandwidth_share > stats.available_bandwidth && !self.config.allow_overcommit {
+            return Err(PartitionError::InsufficientResources {
+                resource: "bandwidth".into(),
+                requested: (limits.bandwidth_share * 100.0) as usize,
+                available: (stats.available_bandwidth * 100.0) as usize,
+            });
+        }
+
+        Ok(())
+    }
+
+    /// Delete a partition
+    pub fn delete_partition(&mut self, id: &str) -> Result<Partition, PartitionError> {
+        let partition = self.partitions.remove(id)
+            .ok_or(PartitionError::NotFound { id: id.into() })?;
+
+        // Remove task mappings
+        self.task_partition.retain(|_, p| p != id);
+
+        // Update default
+        if self.default_partition.as_deref() == Some(id) {
+            self.default_partition = self.partitions.keys().next().cloned();
+        }
+
+        Ok(partition)
+    }
+
+    /// Get a partition
+    pub fn get(&self, id: &str) -> Option<&Partition> {
+        self.partitions.get(id)
+    }
+
+    /// Get a mutable partition
+    pub fn get_mut(&mut self, id: &str) -> Option<&mut Partition> {
+        self.partitions.get_mut(id)
+    }
+
+    /// Assign a task to a partition
+    pub fn assign_task(&mut self, task_id: &str, partition_id: &str) -> Result<(), PartitionError> {
+        if !self.partitions.contains_key(partition_id) {
+            return Err(PartitionError::NotFound { id: partition_id.into() });
+        }
+
+        // Remove from old partition if any
+        if let Some(old_id) = self.task_partition.get(task_id).cloned() {
+            if let Some(old_partition) = self.partitions.get_mut(&old_id) {
+                old_partition.remove_task(task_id);
+            }
+        }
+
+
// Add to new partition + self.task_partition.insert(task_id.into(), partition_id.into()); + if let Some(partition) = self.partitions.get_mut(partition_id) { + partition.add_task(task_id.into()); + } + + Ok(()) + } + + /// Get partition for a task + pub fn get_task_partition(&self, task_id: &str) -> Option<&Partition> { + let partition_id = self.task_partition.get(task_id) + .or(self.default_partition.as_ref())?; + self.partitions.get(partition_id) + } + + /// Get mutable partition for a task + pub fn get_task_partition_mut(&mut self, task_id: &str) -> Option<&mut Partition> { + let partition_id = self.task_partition.get(task_id) + .or(self.default_partition.as_ref()) + .cloned(); + partition_id.and_then(|id| self.partitions.get_mut(&id)) + } + + /// Unassign a task from its partition + pub fn unassign_task(&mut self, task_id: &str) { + if let Some(partition_id) = self.task_partition.remove(task_id) { + if let Some(partition) = self.partitions.get_mut(&partition_id) { + partition.remove_task(task_id); + } + } + } + + /// Set default partition + pub fn set_default(&mut self, id: &str) -> Result<(), PartitionError> { + if !self.partitions.contains_key(id) { + return Err(PartitionError::NotFound { id: id.into() }); + } + self.default_partition = Some(id.into()); + Ok(()) + } + + /// Get default partition + pub fn default_partition(&self) -> Option<&Partition> { + self.default_partition.as_ref().and_then(|id| self.partitions.get(id)) + } + + /// List all partition IDs + pub fn partition_ids(&self) -> Vec<&str> { + self.partitions.keys().map(|s| s.as_str()).collect() + } + + /// Get statistics + pub fn stats(&self) -> PartitionStats { + let mut total_memory = 0; + let mut total_compute = 0.0; + let mut total_bandwidth = 0.0; + let mut active = 0; + + for partition in self.partitions.values() { + if partition.enabled { + active += 1; + total_memory += partition.limits.memory_quota; + total_compute += partition.limits.compute_share; + total_bandwidth += partition.limits.bandwidth_share; + } + } + + let effective_total = if self.config.allow_overcommit { + (self.config.total_memory as f64 * self.config.overcommit_ratio) as usize + } else { + self.config.total_memory + }; + + PartitionStats { + partition_count: self.partitions.len(), + active_partitions: active, + total_memory_allocated: total_memory, + total_compute_allocated: total_compute, + total_bandwidth_allocated: total_bandwidth, + available_memory: effective_total.saturating_sub(total_memory), + available_compute: (self.config.total_compute - total_compute).max(0.0), + available_bandwidth: (self.config.total_bandwidth - total_bandwidth).max(0.0), + } + } + + /// Clear all partitions + pub fn clear(&mut self) { + self.partitions.clear(); + self.task_partition.clear(); + self.default_partition = None; + } + + /// Get config + pub fn config(&self) -> &PartitionConfig { + &self.config + } +} + +/// Partition errors +#[derive(Debug, Clone)] +pub enum PartitionError { + /// Partition already exists + AlreadyExists { id: String }, + /// Partition not found + NotFound { id: String }, + /// Insufficient resources + InsufficientResources { + resource: String, + requested: usize, + available: usize, + }, + /// Operation not allowed + NotAllowed { reason: String }, +} + +impl std::fmt::Display for PartitionError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + PartitionError::AlreadyExists { id } => { + write!(f, "Partition '{}' already exists", id) + } + PartitionError::NotFound { id } => { + write!(f, 
"Partition '{}' not found", id) + } + PartitionError::InsufficientResources { resource, requested, available } => { + write!(f, "Insufficient {}: requested {}, available {}", resource, requested, available) + } + PartitionError::NotAllowed { reason } => { + write!(f, "Operation not allowed: {}", reason) + } + } + } +} + +impl std::error::Error for PartitionError {} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_partition_limits() { + let limits = PartitionLimits::with_memory(1024 * 1024 * 1024) + .compute(0.5) + .bandwidth(0.3); + + assert_eq!(limits.memory_quota, 1024 * 1024 * 1024); + assert!((limits.compute_share - 0.5).abs() < 0.001); + assert!((limits.bandwidth_share - 0.3).abs() < 0.001); + } + + #[test] + fn test_partition_creation() { + let mut manager = PartitionManager::with_defaults(); + + let limits = PartitionLimits::with_memory(1024 * 1024 * 1024); + manager.create_partition("p1", "Partition 1", limits).unwrap(); + + assert!(manager.get("p1").is_some()); + assert_eq!(manager.partition_ids().len(), 1); + } + + #[test] + fn test_duplicate_partition() { + let mut manager = PartitionManager::with_defaults(); + + let limits = PartitionLimits::default(); + manager.create_partition("p1", "Partition 1", limits.clone()).unwrap(); + + let result = manager.create_partition("p1", "Duplicate", limits); + assert!(matches!(result, Err(PartitionError::AlreadyExists { .. }))); + } + + #[test] + fn test_task_assignment() { + let mut manager = PartitionManager::with_defaults(); + + let limits = PartitionLimits::default(); + manager.create_partition("p1", "Partition 1", limits).unwrap(); + + manager.assign_task("task-1", "p1").unwrap(); + + let partition = manager.get_task_partition("task-1"); + assert!(partition.is_some()); + assert_eq!(partition.unwrap().id, "p1"); + } + + #[test] + fn test_default_partition() { + let mut manager = PartitionManager::with_defaults(); + + let limits = PartitionLimits::default(); + manager.create_partition("default", "Default", limits).unwrap(); + + // Unassigned task should use default + let partition = manager.get_task_partition("unknown-task"); + assert!(partition.is_some()); + assert_eq!(partition.unwrap().id, "default"); + } + + #[test] + fn test_memory_allocation() { + let mut manager = PartitionManager::with_defaults(); + + let limits = PartitionLimits::with_memory(1000); + manager.create_partition("p1", "Partition 1", limits).unwrap(); + + let partition = manager.get_mut("p1").unwrap(); + + // Should succeed + assert!(partition.allocate_memory(500)); + assert_eq!(partition.usage.memory_used, 500); + + // Should succeed + assert!(partition.allocate_memory(400)); + assert_eq!(partition.usage.memory_used, 900); + + // Should fail (exceeds quota) + assert!(!partition.allocate_memory(200)); + assert_eq!(partition.usage.memory_used, 900); + + // Free and try again + partition.free_memory(500); + assert!(partition.allocate_memory(200)); + } + + #[test] + fn test_kernel_submission() { + let mut manager = PartitionManager::with_defaults(); + + let limits = PartitionLimits::default().streams(2); + manager.create_partition("p1", "Partition 1", limits).unwrap(); + + let partition = manager.get_mut("p1").unwrap(); + partition.limits.max_pending_kernels = 3; + + assert!(partition.submit_kernel()); + assert!(partition.submit_kernel()); + assert!(partition.submit_kernel()); + assert!(!partition.submit_kernel()); // Limit reached + + partition.complete_kernel(1.0); + assert!(partition.submit_kernel()); + assert_eq!(partition.usage.total_kernels, 1); + 
} + + #[test] + fn test_stream_limits() { + let mut manager = PartitionManager::with_defaults(); + + let limits = PartitionLimits::default().streams(2); + manager.create_partition("p1", "Partition 1", limits).unwrap(); + + let partition = manager.get_mut("p1").unwrap(); + + assert!(partition.register_stream()); + assert!(partition.register_stream()); + assert!(!partition.register_stream()); // Limit + + partition.release_stream(); + assert!(partition.register_stream()); + } + + #[test] + fn test_stats() { + let mut manager = PartitionManager::with_memory(10_000_000_000); // 10GB + + let limits1 = PartitionLimits::with_memory(1_000_000_000).compute(0.3).bandwidth(0.3); + let limits2 = PartitionLimits::with_memory(2_000_000_000).compute(0.4).bandwidth(0.4); + + manager.create_partition("p1", "P1", limits1).unwrap(); + manager.create_partition("p2", "P2", limits2).unwrap(); + + let stats = manager.stats(); + assert_eq!(stats.partition_count, 2); + assert_eq!(stats.total_memory_allocated, 3_000_000_000); + assert!((stats.total_compute_allocated - 0.7).abs() < 0.001); + assert_eq!(stats.available_memory, 7_000_000_000); + } + + #[test] + fn test_delete_partition() { + let mut manager = PartitionManager::with_defaults(); + + // Use partial limits so both partitions can be created + let limits1 = PartitionLimits::with_compute(0.4).bandwidth(0.4); + let limits2 = PartitionLimits::with_compute(0.4).bandwidth(0.4); + manager.create_partition("p1", "P1", limits1).unwrap(); + manager.create_partition("p2", "P2", limits2).unwrap(); + + manager.assign_task("task-1", "p1").unwrap(); + + // Delete p1 + manager.delete_partition("p1").unwrap(); + + assert!(manager.get("p1").is_none()); + assert!(manager.get_task_partition("task-1").is_none() || + manager.get_task_partition("task-1").unwrap().id != "p1"); + } + + #[test] + fn test_reassign_task() { + let mut manager = PartitionManager::with_defaults(); + + // Use partial limits so both partitions can be created + let limits1 = PartitionLimits::with_compute(0.4).bandwidth(0.4); + let limits2 = PartitionLimits::with_compute(0.4).bandwidth(0.4); + manager.create_partition("p1", "P1", limits1).unwrap(); + manager.create_partition("p2", "P2", limits2).unwrap(); + + manager.assign_task("task-1", "p1").unwrap(); + assert!(manager.get("p1").unwrap().has_task("task-1")); + + manager.assign_task("task-1", "p2").unwrap(); + assert!(!manager.get("p1").unwrap().has_task("task-1")); + assert!(manager.get("p2").unwrap().has_task("task-1")); + } +} diff --git a/rust/pygpukit-core/src/scheduler/qos.rs b/rust/pygpukit-core/src/scheduler/qos.rs new file mode 100644 index 0000000..a79cadb --- /dev/null +++ b/rust/pygpukit-core/src/scheduler/qos.rs @@ -0,0 +1,612 @@ +//! QoS (Quality of Service) Policy Framework +//! +//! Provides Kubernetes-style QoS tiers for GPU task scheduling: +//! - Guaranteed: Reserved resources, highest priority +//! - Burstable: Partial reservations, can use spare capacity +//! 
- BestEffort: No reservations, lowest priority + +use crate::scheduler::task::TaskMeta; + +/// QoS class for GPU tasks (Kubernetes-style) +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] +pub enum QosClass { + /// Guaranteed: Full resource reservation + /// - Memory and bandwidth are fully reserved + /// - Never preempted + /// - Highest scheduling priority + Guaranteed, + + /// Burstable: Partial resource reservation + /// - Memory request is reserved, can burst to limit + /// - May be throttled under contention + /// - Medium scheduling priority + #[default] + Burstable, + + /// BestEffort: No resource reservation + /// - No guaranteed resources + /// - First to be evicted under pressure + /// - Lowest scheduling priority + BestEffort, +} + +impl QosClass { + /// Get scheduling priority for this QoS class + pub fn priority(&self) -> i32 { + match self { + QosClass::Guaranteed => 100, + QosClass::Burstable => 50, + QosClass::BestEffort => 0, + } + } + + /// Check if this class can preempt another + pub fn can_preempt(&self, other: &QosClass) -> bool { + self.priority() > other.priority() + } + + /// Get memory overcommit ratio for this class + pub fn memory_overcommit_ratio(&self) -> f64 { + match self { + QosClass::Guaranteed => 1.0, // No overcommit + QosClass::Burstable => 1.5, // 50% overcommit allowed + QosClass::BestEffort => 2.0, // 100% overcommit allowed + } + } + + /// Get bandwidth allocation ratio for this class + pub fn bandwidth_ratio(&self) -> f64 { + match self { + QosClass::Guaranteed => 1.0, // Full bandwidth + QosClass::Burstable => 0.8, // 80% bandwidth + QosClass::BestEffort => 0.5, // 50% bandwidth + } + } + + /// Check if resources should be reserved for this class + pub fn reserves_resources(&self) -> bool { + match self { + QosClass::Guaranteed => true, + QosClass::Burstable => true, + QosClass::BestEffort => false, + } + } +} + +/// Resource requirements for a task +#[derive(Debug, Clone, Default)] +pub struct ResourceRequirements { + /// Minimum memory required (request) + pub memory_request: usize, + /// Maximum memory allowed (limit) + pub memory_limit: usize, + /// Minimum bandwidth required (0.0 - 1.0) + pub bandwidth_request: f64, + /// Maximum bandwidth allowed (0.0 - 1.0) + pub bandwidth_limit: f64, +} + +impl ResourceRequirements { + /// Create new resource requirements + pub fn new(memory_request: usize, memory_limit: usize) -> Self { + Self { + memory_request, + memory_limit, + bandwidth_request: 0.0, + bandwidth_limit: 1.0, + } + } + + /// Create requirements with just memory limit (request = limit) + pub fn guaranteed(memory: usize) -> Self { + Self { + memory_request: memory, + memory_limit: memory, + bandwidth_request: 0.0, + bandwidth_limit: 1.0, + } + } + + /// Create requirements with request/limit ratio + pub fn burstable(memory_request: usize, burst_ratio: f64) -> Self { + Self { + memory_request, + memory_limit: (memory_request as f64 * burst_ratio) as usize, + bandwidth_request: 0.0, + bandwidth_limit: 1.0, + } + } + + /// Create best-effort requirements (no limits) + pub fn best_effort() -> Self { + Self { + memory_request: 0, + memory_limit: usize::MAX, + bandwidth_request: 0.0, + bandwidth_limit: 1.0, + } + } + + /// Set bandwidth requirements + pub fn with_bandwidth(mut self, request: f64, limit: f64) -> Self { + self.bandwidth_request = request; + self.bandwidth_limit = limit; + self + } +} + +/// QoS policy configuration +#[derive(Debug, Clone)] +pub struct QosPolicy { + /// QoS class + pub class: QosClass, + /// Resource 
requirements + pub resources: ResourceRequirements, +} + +impl Default for QosPolicy { + fn default() -> Self { + Self { + class: QosClass::Burstable, + resources: ResourceRequirements::default(), + } + } +} + +impl QosPolicy { + /// Create a Guaranteed policy + pub fn guaranteed(memory: usize) -> Self { + Self { + class: QosClass::Guaranteed, + resources: ResourceRequirements::guaranteed(memory), + } + } + + /// Create a Burstable policy + pub fn burstable(memory_request: usize, burst_ratio: f64) -> Self { + Self { + class: QosClass::Burstable, + resources: ResourceRequirements::burstable(memory_request, burst_ratio), + } + } + + /// Create a BestEffort policy + pub fn best_effort() -> Self { + Self { + class: QosClass::BestEffort, + resources: ResourceRequirements::best_effort(), + } + } + + /// Get effective scheduling priority + pub fn effective_priority(&self, base_priority: i32) -> i32 { + base_priority + self.class.priority() + } + + /// Check if this policy can be satisfied with available resources + pub fn can_satisfy(&self, available_memory: usize, available_bandwidth: f64) -> bool { + self.resources.memory_request <= available_memory + && self.resources.bandwidth_request <= available_bandwidth + } + + /// Get amount of memory to reserve + pub fn memory_to_reserve(&self) -> usize { + match self.class { + QosClass::Guaranteed => self.resources.memory_limit, + QosClass::Burstable => self.resources.memory_request, + QosClass::BestEffort => 0, + } + } + + /// Get amount of bandwidth to reserve + pub fn bandwidth_to_reserve(&self) -> f64 { + match self.class { + QosClass::Guaranteed => self.resources.bandwidth_limit, + QosClass::Burstable => self.resources.bandwidth_request, + QosClass::BestEffort => 0.0, + } + } +} + +/// QoS-aware task metadata +#[derive(Debug, Clone)] +pub struct QosTaskMeta { + /// Base task metadata + pub task: TaskMeta, + /// QoS policy + pub qos: QosPolicy, +} + +impl QosTaskMeta { + /// Create a new QoS task + pub fn new(task: TaskMeta, qos: QosPolicy) -> Self { + Self { task, qos } + } + + /// Create a Guaranteed task + pub fn guaranteed(id: String, name: String, memory: usize) -> Self { + let task = TaskMeta::with_memory(id, name, memory); + Self { + task, + qos: QosPolicy::guaranteed(memory), + } + } + + /// Create a Burstable task + pub fn burstable(id: String, name: String, memory_request: usize, burst_ratio: f64) -> Self { + let task = TaskMeta::with_memory(id, name, memory_request); + Self { + task, + qos: QosPolicy::burstable(memory_request, burst_ratio), + } + } + + /// Create a BestEffort task + pub fn best_effort(id: String, name: String) -> Self { + let task = TaskMeta::new(id, name); + Self { + task, + qos: QosPolicy::best_effort(), + } + } + + /// Get effective priority + pub fn effective_priority(&self) -> i32 { + self.qos.effective_priority(self.task.priority) + } +} + +/// QoS evaluation result +#[derive(Debug, Clone)] +pub enum QosEvaluation { + /// Task should be admitted with the given QoS class + Admit { + class: QosClass, + reserved_memory: usize, + reserved_bandwidth: f64, + }, + /// Task should be throttled (Burstable exceeding request) + Throttle { + class: QosClass, + allowed_memory: usize, + allowed_bandwidth: f64, + }, + /// Task should be queued (BestEffort waiting) + Queue { + position: usize, + }, + /// Task should be rejected + Reject { + reason: String, + }, +} + +impl QosEvaluation { + /// Check if admitted + pub fn is_admitted(&self) -> bool { + matches!(self, Self::Admit { .. 
}) + } + + /// Check if throttled + pub fn is_throttled(&self) -> bool { + matches!(self, Self::Throttle { .. }) + } + + /// Check if queued + pub fn is_queued(&self) -> bool { + matches!(self, Self::Queue { .. }) + } + + /// Check if rejected + pub fn is_rejected(&self) -> bool { + matches!(self, Self::Reject { .. }) + } +} + +/// QoS policy evaluator +#[derive(Debug, Clone)] +pub struct QosPolicyEvaluator { + /// Total system memory + total_memory: usize, + /// Total system bandwidth + total_bandwidth: f64, + /// Reserved memory for Guaranteed tasks + guaranteed_memory: usize, + /// Reserved bandwidth for Guaranteed tasks + guaranteed_bandwidth: f64, + /// Memory used by Burstable tasks + burstable_memory: usize, + /// Best-effort queue count + best_effort_queue: usize, +} + +impl QosPolicyEvaluator { + /// Create a new evaluator + pub fn new(total_memory: usize, total_bandwidth: f64) -> Self { + Self { + total_memory, + total_bandwidth, + guaranteed_memory: 0, + guaranteed_bandwidth: 0.0, + burstable_memory: 0, + best_effort_queue: 0, + } + } + + /// Evaluate QoS policy for a task + pub fn evaluate(&self, qos_task: &QosTaskMeta) -> QosEvaluation { + let policy = &qos_task.qos; + let available_memory = self.total_memory + .saturating_sub(self.guaranteed_memory) + .saturating_sub(self.burstable_memory); + let available_bandwidth = (self.total_bandwidth - self.guaranteed_bandwidth).max(0.0); + + match policy.class { + QosClass::Guaranteed => { + // Guaranteed tasks need full resource reservation + if policy.can_satisfy(available_memory, available_bandwidth) { + QosEvaluation::Admit { + class: QosClass::Guaranteed, + reserved_memory: policy.memory_to_reserve(), + reserved_bandwidth: policy.bandwidth_to_reserve(), + } + } else { + QosEvaluation::Reject { + reason: format!( + "Insufficient resources for Guaranteed task: need {} bytes, {} bandwidth", + policy.resources.memory_request, policy.resources.bandwidth_request + ), + } + } + } + QosClass::Burstable => { + // Burstable tasks can be admitted with throttling + if policy.resources.memory_request <= available_memory { + let allowed_memory = policy.resources.memory_limit.min(available_memory); + let allowed_bandwidth = policy.resources.bandwidth_limit.min(available_bandwidth); + + if allowed_memory < policy.resources.memory_limit + || allowed_bandwidth < policy.resources.bandwidth_limit + { + QosEvaluation::Throttle { + class: QosClass::Burstable, + allowed_memory, + allowed_bandwidth, + } + } else { + QosEvaluation::Admit { + class: QosClass::Burstable, + reserved_memory: policy.memory_to_reserve(), + reserved_bandwidth: policy.bandwidth_to_reserve(), + } + } + } else { + QosEvaluation::Reject { + reason: format!( + "Insufficient memory for Burstable task: need {} bytes", + policy.resources.memory_request + ), + } + } + } + QosClass::BestEffort => { + // BestEffort tasks are queued if resources unavailable + if available_memory > 0 && available_bandwidth > 0.0 { + QosEvaluation::Admit { + class: QosClass::BestEffort, + reserved_memory: 0, + reserved_bandwidth: 0.0, + } + } else { + QosEvaluation::Queue { + position: self.best_effort_queue, + } + } + } + } + } + + /// Reserve resources for an admitted task + pub fn reserve(&mut self, evaluation: &QosEvaluation) { + match evaluation { + QosEvaluation::Admit { + class, + reserved_memory, + reserved_bandwidth, + } => match class { + QosClass::Guaranteed => { + self.guaranteed_memory += reserved_memory; + self.guaranteed_bandwidth += reserved_bandwidth; + } + QosClass::Burstable => { + 
self.burstable_memory += reserved_memory; + } + QosClass::BestEffort => { + // No reservation for best-effort + } + }, + QosEvaluation::Throttle { .. } => { + // Throttled tasks use limited resources + } + QosEvaluation::Queue { .. } => { + self.best_effort_queue += 1; + } + QosEvaluation::Reject { .. } => {} + } + } + + /// Release resources when a task completes + pub fn release(&mut self, class: QosClass, memory: usize, bandwidth: f64) { + match class { + QosClass::Guaranteed => { + self.guaranteed_memory = self.guaranteed_memory.saturating_sub(memory); + self.guaranteed_bandwidth = (self.guaranteed_bandwidth - bandwidth).max(0.0); + } + QosClass::Burstable => { + self.burstable_memory = self.burstable_memory.saturating_sub(memory); + } + QosClass::BestEffort => { + self.best_effort_queue = self.best_effort_queue.saturating_sub(1); + } + } + } + + /// Get statistics + pub fn stats(&self) -> QosStats { + QosStats { + total_memory: self.total_memory, + total_bandwidth: self.total_bandwidth, + guaranteed_memory: self.guaranteed_memory, + guaranteed_bandwidth: self.guaranteed_bandwidth, + burstable_memory: self.burstable_memory, + best_effort_queue: self.best_effort_queue, + available_memory: self.total_memory + .saturating_sub(self.guaranteed_memory) + .saturating_sub(self.burstable_memory), + available_bandwidth: (self.total_bandwidth - self.guaranteed_bandwidth).max(0.0), + } + } + + /// Reset evaluator state + pub fn reset(&mut self) { + self.guaranteed_memory = 0; + self.guaranteed_bandwidth = 0.0; + self.burstable_memory = 0; + self.best_effort_queue = 0; + } +} + +/// QoS statistics +#[derive(Debug, Clone, Default)] +pub struct QosStats { + /// Total system memory + pub total_memory: usize, + /// Total system bandwidth + pub total_bandwidth: f64, + /// Memory reserved for Guaranteed tasks + pub guaranteed_memory: usize, + /// Bandwidth reserved for Guaranteed tasks + pub guaranteed_bandwidth: f64, + /// Memory used by Burstable tasks + pub burstable_memory: usize, + /// Number of BestEffort tasks in queue + pub best_effort_queue: usize, + /// Available memory + pub available_memory: usize, + /// Available bandwidth + pub available_bandwidth: f64, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_qos_class_priority() { + assert!(QosClass::Guaranteed.priority() > QosClass::Burstable.priority()); + assert!(QosClass::Burstable.priority() > QosClass::BestEffort.priority()); + } + + #[test] + fn test_qos_class_preemption() { + assert!(QosClass::Guaranteed.can_preempt(&QosClass::Burstable)); + assert!(QosClass::Guaranteed.can_preempt(&QosClass::BestEffort)); + assert!(QosClass::Burstable.can_preempt(&QosClass::BestEffort)); + assert!(!QosClass::BestEffort.can_preempt(&QosClass::Guaranteed)); + } + + #[test] + fn test_guaranteed_task() { + let mut evaluator = QosPolicyEvaluator::new(1000, 1.0); + let task = QosTaskMeta::guaranteed("task-1".into(), "Test".into(), 500); + + let eval = evaluator.evaluate(&task); + assert!(eval.is_admitted()); + + if let QosEvaluation::Admit { + class, + reserved_memory, + .. 
+ } = &eval + { + assert_eq!(*class, QosClass::Guaranteed); + assert_eq!(*reserved_memory, 500); + } + + evaluator.reserve(&eval); + assert_eq!(evaluator.stats().guaranteed_memory, 500); + } + + #[test] + fn test_burstable_task() { + let mut evaluator = QosPolicyEvaluator::new(1000, 1.0); + let task = QosTaskMeta::burstable("task-1".into(), "Test".into(), 300, 2.0); + + let eval = evaluator.evaluate(&task); + assert!(eval.is_admitted()); + + evaluator.reserve(&eval); + assert_eq!(evaluator.stats().burstable_memory, 300); + } + + #[test] + fn test_burstable_throttled() { + let mut evaluator = QosPolicyEvaluator::new(500, 1.0); + + // First task uses some memory + let task1 = QosTaskMeta::guaranteed("task-1".into(), "Test".into(), 200); + let eval1 = evaluator.evaluate(&task1); + evaluator.reserve(&eval1); + + // Second task requests 200 but can burst to 400 + let task2 = QosTaskMeta::burstable("task-2".into(), "Test".into(), 200, 2.0); + let eval2 = evaluator.evaluate(&task2); + + // Should be throttled because burst limit (400) exceeds available (300) + assert!(eval2.is_throttled()); + } + + #[test] + fn test_best_effort_task() { + let evaluator = QosPolicyEvaluator::new(1000, 1.0); + let task = QosTaskMeta::best_effort("task-1".into(), "Test".into()); + + let eval = evaluator.evaluate(&task); + assert!(eval.is_admitted()); + + if let QosEvaluation::Admit { + reserved_memory, .. + } = &eval + { + assert_eq!(*reserved_memory, 0); // No reservation + } + } + + #[test] + fn test_best_effort_queued() { + let evaluator = QosPolicyEvaluator::new(0, 0.0); // No resources + let task = QosTaskMeta::best_effort("task-1".into(), "Test".into()); + + let eval = evaluator.evaluate(&task); + assert!(eval.is_queued()); + } + + #[test] + fn test_guaranteed_reject() { + let evaluator = QosPolicyEvaluator::new(100, 1.0); + let task = QosTaskMeta::guaranteed("task-1".into(), "Test".into(), 500); + + let eval = evaluator.evaluate(&task); + assert!(eval.is_rejected()); + } + + #[test] + fn test_effective_priority() { + let guaranteed = QosTaskMeta::guaranteed("g".into(), "G".into(), 100); + let burstable = QosTaskMeta::burstable("b".into(), "B".into(), 100, 1.5); + let best_effort = QosTaskMeta::best_effort("e".into(), "E".into()); + + assert!(guaranteed.effective_priority() > burstable.effective_priority()); + assert!(burstable.effective_priority() > best_effort.effective_priority()); + } +} diff --git a/rust/pygpukit-core/src/transfer/mod.rs b/rust/pygpukit-core/src/transfer/mod.rs index 15f9ec1..4a83975 100644 --- a/rust/pygpukit-core/src/transfer/mod.rs +++ b/rust/pygpukit-core/src/transfer/mod.rs @@ -5,12 +5,17 @@ //! - H2D, D2H, and D2D transfer types //! - Stream synchronization model //! - Integration with the scheduler tick loop +//! - Pinned memory management for fast transfers //! //! Note: This module provides the Rust-side coordination logic. //! Actual CUDA stream operations are handled by the C++ backend via callbacks. mod operation; mod engine; +mod pinned; pub use operation::{TransferType, TransferOp, TransferState}; pub use engine::{AsyncTransferEngine, StreamType, TransferStats, TransferCallback}; +pub use pinned::{ + PinnedMemoryManager, PinnedPoolConfig, PinnedBlock, PinnedStats, PinnedError, +}; diff --git a/rust/pygpukit-core/src/transfer/pinned.rs b/rust/pygpukit-core/src/transfer/pinned.rs new file mode 100644 index 0000000..e6b9a24 --- /dev/null +++ b/rust/pygpukit-core/src/transfer/pinned.rs @@ -0,0 +1,628 @@ +//! Pinned Memory Manager +//! +//! 
Manages page-locked (pinned) host memory for faster CPU-GPU transfers.
+//! Pinned memory enables DMA transfers and can significantly improve
+//! transfer bandwidth compared to pageable memory.
+
+use std::collections::HashMap;
+
+/// Pinned memory block information
+#[derive(Debug, Clone)]
+pub struct PinnedBlock {
+    /// Block ID
+    pub id: u64,
+    /// Host pointer (virtual address)
+    pub host_ptr: u64,
+    /// Size in bytes
+    pub size: usize,
+    /// Whether the block is currently in use
+    pub in_use: bool,
+    /// Associated task ID
+    pub task_id: Option<String>,
+    /// Allocation timestamp
+    pub allocated_at: f64,
+    /// Last access timestamp
+    pub last_access: f64,
+}
+
+impl PinnedBlock {
+    /// Create a new pinned block
+    pub fn new(id: u64, host_ptr: u64, size: usize) -> Self {
+        let now = std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .map(|d| d.as_secs_f64())
+            .unwrap_or(0.0);
+        Self {
+            id,
+            host_ptr,
+            size,
+            in_use: true,
+            task_id: None,
+            allocated_at: now,
+            last_access: now,
+        }
+    }
+
+    /// Associate with a task
+    pub fn with_task(mut self, task_id: String) -> Self {
+        self.task_id = Some(task_id);
+        self
+    }
+
+    /// Touch to update last access time
+    pub fn touch(&mut self) {
+        self.last_access = std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .map(|d| d.as_secs_f64())
+            .unwrap_or(0.0);
+    }
+}
+
+/// Pinned memory pool configuration
+#[derive(Debug, Clone)]
+pub struct PinnedPoolConfig {
+    /// Maximum total pinned memory in bytes
+    pub max_size: usize,
+    /// Enable pooling for reuse
+    pub enable_pooling: bool,
+    /// Size classes for pooling (in bytes)
+    pub size_classes: Vec<usize>,
+    /// Default alignment
+    pub alignment: usize,
+}
+
+impl Default for PinnedPoolConfig {
+    fn default() -> Self {
+        Self {
+            max_size: 1024 * 1024 * 1024, // 1GB default
+            enable_pooling: true,
+            size_classes: vec![
+                4096,      // 4KB
+                65536,     // 64KB
+                262144,    // 256KB
+                1048576,   // 1MB
+                4194304,   // 4MB
+                16777216,  // 16MB
+                67108864,  // 64MB
+                268435456, // 256MB
+            ],
+            alignment: 256, // CUDA alignment
+        }
+    }
+}
+
+impl PinnedPoolConfig {
+    /// Create with max size
+    pub fn with_max_size(max_size: usize) -> Self {
+        Self {
+            max_size,
+            ..Default::default()
+        }
+    }
+
+    /// Set pooling enabled
+    pub fn enable_pooling(mut self, enable: bool) -> Self {
+        self.enable_pooling = enable;
+        self
+    }
+
+    /// Get size class for a given size
+    pub fn get_size_class(&self, size: usize) -> usize {
+        for &class_size in &self.size_classes {
+            if size <= class_size {
+                return class_size;
+            }
+        }
+        // Round up to alignment
+        let aligned = (size + self.alignment - 1) / self.alignment * self.alignment;
+        aligned
+    }
+}
+
+/// Pinned memory manager statistics
+#[derive(Debug, Clone, Default)]
+pub struct PinnedStats {
+    /// Total bytes allocated
+    pub total_allocated: usize,
+    /// Current bytes in use
+    pub current_used: usize,
+    /// Peak bytes used
+    pub peak_used: usize,
+    /// Total allocations
+    pub total_allocations: usize,
+    /// Total frees
+    pub total_frees: usize,
+    /// Pool hits (reused from pool)
+    pub pool_hits: usize,
+    /// Pool misses (new allocation)
+    pub pool_misses: usize,
+    /// Current pool size
+    pub pool_size: usize,
+    /// Blocks currently pooled
+    pub pooled_blocks: usize,
+}
+
+/// Pinned memory manager
+///
+/// Manages pinned host memory allocations with optional pooling
+/// for efficient reuse.
+#[derive(Debug)]
+pub struct PinnedMemoryManager {
+    config: PinnedPoolConfig,
+    /// Active allocations by ID
+    active: HashMap<u64, PinnedBlock>,
+    /// Free pool by size class
+    pool: HashMap<usize, Vec<PinnedBlock>>,
+    /// Next block ID
+    next_id: u64,
+    /// Statistics
+    total_allocated: usize,
+    current_used: usize,
+    peak_used: usize,
+    total_allocations: usize,
+    total_frees: usize,
+    pool_hits: usize,
+    pool_misses: usize,
+}
+
+impl PinnedMemoryManager {
+    /// Create a new pinned memory manager
+    pub fn new(config: PinnedPoolConfig) -> Self {
+        Self {
+            config,
+            active: HashMap::new(),
+            pool: HashMap::new(),
+            next_id: 1,
+            total_allocated: 0,
+            current_used: 0,
+            peak_used: 0,
+            total_allocations: 0,
+            total_frees: 0,
+            pool_hits: 0,
+            pool_misses: 0,
+        }
+    }
+
+    /// Create with default config
+    pub fn with_defaults() -> Self {
+        Self::new(PinnedPoolConfig::default())
+    }
+
+    /// Create with max size
+    pub fn with_max_size(max_size: usize) -> Self {
+        Self::new(PinnedPoolConfig::with_max_size(max_size))
+    }
+
+    /// Check if allocation would exceed quota
+    pub fn can_allocate(&self, size: usize) -> bool {
+        // If pooling is enabled and we have a suitable block, always allow
+        if self.config.enable_pooling {
+            let size_class = self.config.get_size_class(size);
+            if let Some(blocks) = self.pool.get(&size_class) {
+                if !blocks.is_empty() {
+                    return true;
+                }
+            }
+        }
+        // Check quota
+        self.current_used + size <= self.config.max_size
+    }
+
+    /// Allocate pinned memory
+    ///
+    /// If pooling is enabled, may return a previously freed block.
+    /// Returns the block ID and size class (actual allocated size).
+    ///
+    /// The `host_ptr` should be provided by the caller after performing
+    /// the actual cudaHostAlloc call if this is a new allocation.
+    pub fn allocate(&mut self, size: usize) -> Result<(u64, usize, bool), PinnedError> {
+        let size_class = self.config.get_size_class(size);
+
+        // Try to get from pool first
+        if self.config.enable_pooling {
+            if let Some(blocks) = self.pool.get_mut(&size_class) {
+                if let Some(mut block) = blocks.pop() {
+                    block.in_use = true;
+                    block.touch();
+                    let id = block.id;
+                    self.active.insert(id, block);
+                    self.current_used += size_class;
+                    self.peak_used = self.peak_used.max(self.current_used);
+                    self.total_allocations += 1;
+                    self.pool_hits += 1;
+                    // Return (id, size_class, reused=true)
+                    return Ok((id, size_class, true));
+                }
+            }
+        }
+
+        // Check quota
+        if self.current_used + size_class > self.config.max_size {
+            return Err(PinnedError::QuotaExceeded {
+                requested: size_class,
+                available: self.config.max_size.saturating_sub(self.current_used),
+            });
+        }
+
+        // Need new allocation
+        let id = self.next_id;
+        self.next_id += 1;
+        self.total_allocations += 1;
+        self.total_allocated += size_class;
+        self.current_used += size_class;
+        self.peak_used = self.peak_used.max(self.current_used);
+        self.pool_misses += 1;
+
+        // Return (id, size_class, reused=false) - caller must allocate
+        Ok((id, size_class, false))
+    }
+
+    /// Register an allocated block
+    ///
+    /// Called after cudaHostAlloc succeeds with the actual pointer.
+    pub fn register(&mut self, id: u64, host_ptr: u64, size: usize) {
+        let size_class = self.config.get_size_class(size);
+        let block = PinnedBlock::new(id, host_ptr, size_class);
+        self.active.insert(id, block);
+    }
+
+    /// Free a pinned block
+    ///
+    /// If pooling is enabled, the block is returned to the pool.
+    /// Returns (should_free, host_ptr) - if should_free is true,
+    /// the caller should call cudaFreeHost.
+ pub fn free(&mut self, id: u64) -> Result<(bool, u64), PinnedError> { + let block = self.active.remove(&id) + .ok_or(PinnedError::InvalidBlock { id })?; + + self.total_frees += 1; + self.current_used = self.current_used.saturating_sub(block.size); + + // Return to pool if enabled + if self.config.enable_pooling { + let size_class = block.size; + let mut pooled_block = block.clone(); + pooled_block.in_use = false; + pooled_block.task_id = None; + + self.pool + .entry(size_class) + .or_insert_with(Vec::new) + .push(pooled_block); + + // Don't free, return to pool + Ok((false, block.host_ptr)) + } else { + // Free immediately + Ok((true, block.host_ptr)) + } + } + + /// Associate a block with a task + pub fn associate_task(&mut self, id: u64, task_id: String) -> Result<(), PinnedError> { + let block = self.active.get_mut(&id) + .ok_or(PinnedError::InvalidBlock { id })?; + block.task_id = Some(task_id); + Ok(()) + } + + /// Get a block by ID + pub fn get(&self, id: u64) -> Option<&PinnedBlock> { + self.active.get(&id) + } + + /// Touch a block to update access time + pub fn touch(&mut self, id: u64) -> Result<(), PinnedError> { + let block = self.active.get_mut(&id) + .ok_or(PinnedError::InvalidBlock { id })?; + block.touch(); + Ok(()) + } + + /// Get blocks for a task + pub fn get_blocks_for_task(&self, task_id: &str) -> Vec<&PinnedBlock> { + self.active.values() + .filter(|b| b.task_id.as_deref() == Some(task_id)) + .collect() + } + + /// Free all blocks for a task + /// + /// Returns list of (should_free, host_ptr) for blocks to potentially free. + pub fn free_task_blocks(&mut self, task_id: &str) -> Vec<(bool, u64)> { + let ids: Vec = self.active.values() + .filter(|b| b.task_id.as_deref() == Some(task_id)) + .map(|b| b.id) + .collect(); + + ids.into_iter() + .filter_map(|id| self.free(id).ok()) + .collect() + } + + /// Get statistics + pub fn stats(&self) -> PinnedStats { + let pool_size: usize = self.pool.values() + .flat_map(|v| v.iter()) + .map(|b| b.size) + .sum(); + let pooled_blocks: usize = self.pool.values() + .map(|v| v.len()) + .sum(); + + PinnedStats { + total_allocated: self.total_allocated, + current_used: self.current_used, + peak_used: self.peak_used, + total_allocations: self.total_allocations, + total_frees: self.total_frees, + pool_hits: self.pool_hits, + pool_misses: self.pool_misses, + pool_size, + pooled_blocks, + } + } + + /// Clear the pool (free all pooled blocks) + /// + /// Returns list of host pointers to free. + pub fn clear_pool(&mut self) -> Vec { + let mut ptrs = Vec::new(); + for blocks in self.pool.values() { + for block in blocks { + ptrs.push(block.host_ptr); + } + } + self.pool.clear(); + ptrs + } + + /// Clear all state (active + pool) + /// + /// Returns list of all host pointers to free. 
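+    // Typical lifecycle, as an illustrative sketch only (`cuda_host_alloc` and
+    // `cuda_free_host` stand in for the backend calls that actually pin/unpin memory):
+    //
+    //     let (id, size_class, reused) = manager.allocate(len)?;
+    //     if !reused {
+    //         let host_ptr = cuda_host_alloc(size_class); // backend call
+    //         manager.register(id, host_ptr, size_class);
+    //     }
+    //     // ... use the block, then:
+    //     let (should_free, host_ptr) = manager.free(id)?;
+    //     if should_free {
+    //         cuda_free_host(host_ptr); // only when the block was not pooled
+    //     }
+    //     // At shutdown, clear() hands back every remaining pointer for freeing.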
+ pub fn clear(&mut self) -> Vec { + let mut ptrs: Vec = self.active.values() + .map(|b| b.host_ptr) + .collect(); + + ptrs.extend(self.clear_pool()); + + self.active.clear(); + self.current_used = 0; + + ptrs + } + + /// Get config + pub fn config(&self) -> &PinnedPoolConfig { + &self.config + } + + /// Get number of active blocks + pub fn active_count(&self) -> usize { + self.active.len() + } +} + +/// Pinned memory errors +#[derive(Debug, Clone)] +pub enum PinnedError { + /// Quota exceeded + QuotaExceeded { + requested: usize, + available: usize, + }, + /// Invalid block ID + InvalidBlock { + id: u64, + }, + /// Allocation failed + AllocationFailed { + reason: String, + }, +} + +impl std::fmt::Display for PinnedError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + PinnedError::QuotaExceeded { requested, available } => { + write!(f, "Pinned memory quota exceeded: requested {} bytes, {} available", requested, available) + } + PinnedError::InvalidBlock { id } => { + write!(f, "Invalid pinned block ID: {}", id) + } + PinnedError::AllocationFailed { reason } => { + write!(f, "Pinned memory allocation failed: {}", reason) + } + } + } +} + +impl std::error::Error for PinnedError {} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_pinned_config() { + let config = PinnedPoolConfig::with_max_size(1024 * 1024) + .enable_pooling(true); + + assert_eq!(config.max_size, 1024 * 1024); + assert!(config.enable_pooling); + } + + #[test] + fn test_size_class() { + let config = PinnedPoolConfig::default(); + + // Small allocation should use 4KB class + assert_eq!(config.get_size_class(100), 4096); + + // 10KB should use 64KB class + assert_eq!(config.get_size_class(10000), 65536); + + // 100KB should use 256KB class + assert_eq!(config.get_size_class(100_000), 262144); + } + + #[test] + fn test_allocate_new() { + let mut manager = PinnedMemoryManager::with_max_size(1024 * 1024); + + // Allocate 1KB + let result = manager.allocate(1024); + assert!(result.is_ok()); + + let (id, size_class, reused) = result.unwrap(); + assert_eq!(id, 1); + assert_eq!(size_class, 4096); // Rounded to 4KB + assert!(!reused); + + // Register the block + manager.register(id, 0x1000, 4096); + assert_eq!(manager.active_count(), 1); + } + + #[test] + fn test_pool_reuse() { + let mut manager = PinnedMemoryManager::with_max_size(1024 * 1024); + + // Allocate and register + let (id1, size_class, _) = manager.allocate(1024).unwrap(); + manager.register(id1, 0x1000, size_class); + + // Free to pool + let (should_free, _) = manager.free(id1).unwrap(); + assert!(!should_free); // Pooled, not freed + + // Allocate again - should reuse from pool + let (id2, _, reused) = manager.allocate(1024).unwrap(); + assert!(reused); + assert_eq!(manager.stats().pool_hits, 1); + + // Same ID since reused from pool + let block = manager.get(id2).unwrap(); + assert_eq!(block.host_ptr, 0x1000); + } + + #[test] + fn test_quota_exceeded() { + let mut manager = PinnedMemoryManager::with_max_size(4096); // Only 4KB + + // First allocation should succeed + let result1 = manager.allocate(1024); + assert!(result1.is_ok()); + let (id1, size_class, _) = result1.unwrap(); + manager.register(id1, 0x1000, size_class); + + // Second allocation should fail (quota exceeded) + let result2 = manager.allocate(1024); + assert!(matches!(result2, Err(PinnedError::QuotaExceeded { .. 
}))); + } + + #[test] + fn test_task_association() { + let mut manager = PinnedMemoryManager::with_max_size(1024 * 1024); + + let (id, size_class, _) = manager.allocate(1024).unwrap(); + manager.register(id, 0x1000, size_class); + + manager.associate_task(id, "task-1".into()).unwrap(); + + let blocks = manager.get_blocks_for_task("task-1"); + assert_eq!(blocks.len(), 1); + assert_eq!(blocks[0].id, id); + } + + #[test] + fn test_free_task_blocks() { + let mut manager = PinnedMemoryManager::with_max_size(1024 * 1024); + + // Allocate two blocks for same task + let (id1, sc1, _) = manager.allocate(1024).unwrap(); + manager.register(id1, 0x1000, sc1); + manager.associate_task(id1, "task-1".into()).unwrap(); + + let (id2, sc2, _) = manager.allocate(2048).unwrap(); + manager.register(id2, 0x2000, sc2); + manager.associate_task(id2, "task-1".into()).unwrap(); + + // Allocate one for different task + let (id3, sc3, _) = manager.allocate(1024).unwrap(); + manager.register(id3, 0x3000, sc3); + manager.associate_task(id3, "task-2".into()).unwrap(); + + assert_eq!(manager.active_count(), 3); + + // Free task-1 blocks + let freed = manager.free_task_blocks("task-1"); + assert_eq!(freed.len(), 2); + + // Only task-2 block remains active + assert_eq!(manager.active_count(), 1); + } + + #[test] + fn test_stats() { + let mut manager = PinnedMemoryManager::with_max_size(1024 * 1024); + + let (id1, sc1, _) = manager.allocate(1024).unwrap(); + manager.register(id1, 0x1000, sc1); + + let (id2, sc2, _) = manager.allocate(2048).unwrap(); + manager.register(id2, 0x2000, sc2); + + let stats = manager.stats(); + assert_eq!(stats.total_allocations, 2); + assert_eq!(stats.pool_misses, 2); + assert_eq!(stats.pool_hits, 0); + + // Free and reallocate + manager.free(id1).unwrap(); + manager.allocate(512).unwrap(); + + let stats2 = manager.stats(); + assert_eq!(stats2.total_allocations, 3); + assert_eq!(stats2.pool_hits, 1); + } + + #[test] + fn test_clear() { + let mut manager = PinnedMemoryManager::with_max_size(1024 * 1024); + + let (id1, sc1, _) = manager.allocate(1024).unwrap(); + manager.register(id1, 0x1000, sc1); + + let (id2, sc2, _) = manager.allocate(2048).unwrap(); + manager.register(id2, 0x2000, sc2); + + manager.free(id1).unwrap(); + + // Clear returns all pointers + let ptrs = manager.clear(); + assert_eq!(ptrs.len(), 2); // 1 active + 1 pooled + + assert_eq!(manager.active_count(), 0); + assert_eq!(manager.stats().current_used, 0); + } + + #[test] + fn test_no_pooling() { + let config = PinnedPoolConfig::with_max_size(1024 * 1024) + .enable_pooling(false); + let mut manager = PinnedMemoryManager::new(config); + + let (id, size_class, _) = manager.allocate(1024).unwrap(); + manager.register(id, 0x1000, size_class); + + // Free should indicate actual free needed + let (should_free, _) = manager.free(id).unwrap(); + assert!(should_free); // Not pooled + + // Next allocation is not reused + let (_, _, reused) = manager.allocate(1024).unwrap(); + assert!(!reused); + } +} diff --git a/rust/pygpukit-python/src/dispatch.rs b/rust/pygpukit-python/src/dispatch.rs index 09e5c18..24251db 100644 --- a/rust/pygpukit-python/src/dispatch.rs +++ b/rust/pygpukit-python/src/dispatch.rs @@ -4,6 +4,9 @@ use pyo3::prelude::*; use std::collections::HashMap; use pygpukit_core::dispatch::{ KernelDispatcher, KernelLaunchRequest, KernelState, DispatchStats, LaunchConfig, + KernelPacingEngine, PacingConfig, PacingDecision, PacingStats, StreamPacingStats, + SliceScheduler, SliceConfig, SlicedKernel, KernelSlice, SliceInfo, 
SliceStats, + KernelCache, CacheConfig, CachedKernel, CompileOptions, CacheStats, }; /// Python wrapper for KernelState enum @@ -406,6 +409,943 @@ impl PyKernelDispatcher { } } +// ============================================================================= +// Kernel Pacing Types +// ============================================================================= + +/// Pacing configuration for Python +#[pyclass(name = "PacingConfig")] +#[derive(Clone)] +pub struct PyPacingConfig { + inner: PacingConfig, +} + +#[pymethods] +impl PyPacingConfig { + #[new] + #[pyo3(signature = (total_bandwidth=1.0, window_ms=100.0, min_interval_ms=0.1, adaptive=true))] + fn new(total_bandwidth: f64, window_ms: f64, min_interval_ms: f64, adaptive: bool) -> Self { + Self { + inner: PacingConfig { + total_bandwidth, + window_ms, + min_interval_ms, + adaptive, + }, + } + } + + #[getter] + fn total_bandwidth(&self) -> f64 { + self.inner.total_bandwidth + } + + #[getter] + fn window_ms(&self) -> f64 { + self.inner.window_ms + } + + #[getter] + fn min_interval_ms(&self) -> f64 { + self.inner.min_interval_ms + } + + #[getter] + fn adaptive(&self) -> bool { + self.inner.adaptive + } + + fn __repr__(&self) -> String { + format!( + "PacingConfig(bandwidth={:.2}, window={}ms, min_interval={}ms)", + self.inner.total_bandwidth, self.inner.window_ms, self.inner.min_interval_ms + ) + } +} + +/// Pacing decision for Python +#[pyclass(name = "PacingDecision")] +#[derive(Clone)] +pub struct PyPacingDecision { + inner: PacingDecision, +} + +#[pymethods] +impl PyPacingDecision { + /// Check if immediate launch is allowed + fn can_launch(&self) -> bool { + self.inner.can_launch() + } + + /// Check if throttled + fn is_throttled(&self) -> bool { + self.inner.is_throttled() + } + + /// Get wait time in milliseconds + fn wait_ms(&self) -> f64 { + self.inner.wait_ms() + } + + #[getter] + fn decision_type(&self) -> String { + match &self.inner { + PacingDecision::Launch => "Launch".into(), + PacingDecision::Wait { .. } => "Wait".into(), + PacingDecision::Throttle { .. 
} => "Throttle".into(), + } + } + + fn __repr__(&self) -> String { + match &self.inner { + PacingDecision::Launch => "PacingDecision(Launch)".into(), + PacingDecision::Wait { delay_ms } => format!("PacingDecision(Wait, delay={:.2}ms)", delay_ms), + PacingDecision::Throttle { reason } => format!("PacingDecision(Throttle, reason='{}')", reason), + } + } +} + +/// Stream pacing statistics for Python +#[pyclass(name = "StreamPacingStats")] +#[derive(Clone)] +pub struct PyStreamPacingStats { + inner: StreamPacingStats, +} + +#[pymethods] +impl PyStreamPacingStats { + #[getter] + fn stream_id(&self) -> u64 { + self.inner.stream_id + } + + #[getter] + fn bandwidth(&self) -> f64 { + self.inner.bandwidth + } + + #[getter] + fn launches_in_window(&self) -> usize { + self.inner.launches_in_window + } + + #[getter] + fn total_launches(&self) -> usize { + self.inner.total_launches + } + + #[getter] + fn throttled_count(&self) -> usize { + self.inner.throttled_count + } + + fn __repr__(&self) -> String { + format!( + "StreamPacingStats(stream={}, bandwidth={:.2}, launches={})", + self.inner.stream_id, self.inner.bandwidth, self.inner.total_launches + ) + } +} + +/// Global pacing statistics for Python +#[pyclass(name = "PacingStats")] +#[derive(Clone)] +pub struct PyPacingStats { + inner: PacingStats, +} + +#[pymethods] +impl PyPacingStats { + #[getter] + fn stream_count(&self) -> usize { + self.inner.stream_count + } + + #[getter] + fn used_bandwidth(&self) -> f64 { + self.inner.used_bandwidth + } + + #[getter] + fn available_bandwidth(&self) -> f64 { + self.inner.available_bandwidth + } + + #[getter] + fn total_launches(&self) -> usize { + self.inner.total_launches + } + + #[getter] + fn total_throttled(&self) -> usize { + self.inner.total_throttled + } + + #[getter] + fn total_waited(&self) -> usize { + self.inner.total_waited + } + + fn __repr__(&self) -> String { + format!( + "PacingStats(streams={}, launches={}, throttled={})", + self.inner.stream_count, self.inner.total_launches, self.inner.total_throttled + ) + } +} + +/// Kernel pacing engine for Python +#[pyclass(name = "KernelPacingEngine")] +pub struct PyKernelPacingEngine { + inner: KernelPacingEngine, +} + +#[pymethods] +impl PyKernelPacingEngine { + #[new] + #[pyo3(signature = (config=None))] + fn new(config: Option) -> Self { + let cfg = config.map(|c| c.inner).unwrap_or_default(); + Self { + inner: KernelPacingEngine::new(cfg), + } + } + + /// Allocate bandwidth for a stream + fn allocate_stream(&mut self, stream_id: u64, bandwidth: f64) -> bool { + self.inner.allocate_stream(stream_id, bandwidth) + } + + /// Release bandwidth for a stream + fn release_stream(&mut self, stream_id: u64) { + self.inner.release_stream(stream_id); + } + + /// Check if a kernel launch should proceed + fn should_launch(&self, stream_id: u64) -> PyPacingDecision { + PyPacingDecision { + inner: self.inner.should_launch(stream_id), + } + } + + /// Record a kernel launch + fn record_launch(&mut self, stream_id: u64) { + self.inner.record_launch(stream_id); + } + + /// Record a throttled request + fn record_throttle(&mut self, stream_id: u64) { + self.inner.record_throttle(stream_id); + } + + /// Record a waited request + fn record_wait(&mut self) { + self.inner.record_wait(); + } + + /// Get stream statistics + fn stream_stats(&self, stream_id: u64) -> Option { + self.inner.stream_stats(stream_id).map(|s| PyStreamPacingStats { inner: s }) + } + + /// Get global statistics + fn stats(&self) -> PyPacingStats { + PyPacingStats { + inner: self.inner.stats(), + } + } 
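+    // Illustrative use of the underlying core engine (sketch only; `stream_id`,
+    // `launch_kernel`, and `sleep_ms` are placeholders, not part of this API):
+    //
+    //     let mut engine = KernelPacingEngine::new(PacingConfig::default());
+    //     engine.allocate_stream(stream_id, 0.5); // half of the total bandwidth budget
+    //     let decision = engine.should_launch(stream_id);
+    //     if decision.can_launch() {
+    //         launch_kernel();
+    //         engine.record_launch(stream_id);
+    //     } else if decision.is_throttled() {
+    //         engine.record_throttle(stream_id);
+    //     } else {
+    //         sleep_ms(decision.wait_ms());
+    //         engine.record_wait();
+    //     }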
+ + /// Reset all pacing state + fn reset(&mut self) { + self.inner.reset(); + } + + fn __repr__(&self) -> String { + let stats = self.inner.stats(); + format!( + "KernelPacingEngine(streams={}, available_bw={:.2})", + stats.stream_count, stats.available_bandwidth + ) + } +} + +// ============================================================================= +// Micro-Slicing Types +// ============================================================================= + +/// Slice configuration for Python +#[pyclass(name = "SliceConfig")] +#[derive(Clone)] +pub struct PySliceConfig { + inner: SliceConfig, +} + +#[pymethods] +impl PySliceConfig { + #[new] + #[pyo3(signature = (max_items_per_slice=65536, max_duration_ms=1.0, min_slices=1, max_slices=256, adaptive=true))] + fn new( + max_items_per_slice: usize, + max_duration_ms: f64, + min_slices: usize, + max_slices: usize, + adaptive: bool, + ) -> Self { + Self { + inner: SliceConfig { + max_items_per_slice, + max_duration_ms, + min_slices, + max_slices, + adaptive, + }, + } + } + + #[getter] + fn max_items_per_slice(&self) -> usize { + self.inner.max_items_per_slice + } + + #[getter] + fn max_duration_ms(&self) -> f64 { + self.inner.max_duration_ms + } + + #[getter] + fn min_slices(&self) -> usize { + self.inner.min_slices + } + + #[getter] + fn max_slices(&self) -> usize { + self.inner.max_slices + } + + #[getter] + fn adaptive(&self) -> bool { + self.inner.adaptive + } + + fn __repr__(&self) -> String { + format!( + "SliceConfig(max_items={}, max_duration={}ms, slices=[{}, {}])", + self.inner.max_items_per_slice, + self.inner.max_duration_ms, + self.inner.min_slices, + self.inner.max_slices + ) + } +} + +/// Single kernel slice for Python +#[pyclass(name = "KernelSlice")] +#[derive(Clone)] +pub struct PyKernelSlice { + inner: KernelSlice, +} + +#[pymethods] +impl PyKernelSlice { + #[getter] + fn id(&self) -> usize { + self.inner.id + } + + #[getter] + fn offset(&self) -> usize { + self.inner.offset + } + + #[getter] + fn count(&self) -> usize { + self.inner.count + } + + #[getter] + fn grid(&self) -> (u32, u32, u32) { + self.inner.grid + } + + #[getter] + fn executed(&self) -> bool { + self.inner.executed + } + + #[getter] + fn exec_time_ms(&self) -> Option { + self.inner.exec_time_ms + } + + fn __repr__(&self) -> String { + format!( + "KernelSlice(id={}, offset={}, count={}, executed={})", + self.inner.id, self.inner.offset, self.inner.count, self.inner.executed + ) + } +} + +/// Information about a slice to execute for Python +#[pyclass(name = "SliceInfo")] +#[derive(Clone)] +pub struct PySliceInfo { + inner: SliceInfo, +} + +#[pymethods] +impl PySliceInfo { + #[getter] + fn kernel_handle(&self) -> u64 { + self.inner.kernel_handle + } + + #[getter] + fn block(&self) -> (u32, u32, u32) { + self.inner.block + } + + #[getter] + fn shared_mem(&self) -> u32 { + self.inner.shared_mem + } + + #[getter] + fn slice_id(&self) -> usize { + self.inner.slice_id + } + + #[getter] + fn offset(&self) -> usize { + self.inner.offset + } + + #[getter] + fn count(&self) -> usize { + self.inner.count + } + + #[getter] + fn grid(&self) -> (u32, u32, u32) { + self.inner.grid + } + + #[getter] + fn task_id(&self) -> Option { + self.inner.task_id.clone() + } + + #[getter] + fn priority(&self) -> i32 { + self.inner.priority + } + + fn __repr__(&self) -> String { + format!( + "SliceInfo(kernel=0x{:x}, slice_id={}, offset={}, count={})", + self.inner.kernel_handle, self.inner.slice_id, self.inner.offset, self.inner.count + ) + } +} + +/// Slice scheduler statistics for 
Python +#[pyclass(name = "SliceStats")] +#[derive(Clone)] +pub struct PySliceStats { + inner: SliceStats, +} + +#[pymethods] +impl PySliceStats { + #[getter] + fn total_slices(&self) -> usize { + self.inner.total_slices + } + + #[getter] + fn completed_slices(&self) -> usize { + self.inner.completed_slices + } + + #[getter] + fn pending_slices(&self) -> usize { + self.inner.pending_slices + } + + #[getter] + fn total_kernels(&self) -> usize { + self.inner.total_kernels + } + + #[getter] + fn completed_kernels(&self) -> usize { + self.inner.completed_kernels + } + + #[getter] + fn pending_kernels(&self) -> usize { + self.inner.pending_kernels + } + + fn __repr__(&self) -> String { + format!( + "SliceStats(slices={}/{}, kernels={}/{})", + self.inner.completed_slices, self.inner.total_slices, + self.inner.completed_kernels, self.inner.total_kernels + ) + } +} + +/// Slice scheduler for Python +/// +/// Splits kernels into smaller slices for fair scheduling +/// and better latency under QoS constraints. +#[pyclass(name = "SliceScheduler")] +pub struct PySliceScheduler { + inner: SliceScheduler, +} + +#[pymethods] +impl PySliceScheduler { + #[new] + #[pyo3(signature = (config=None))] + fn new(config: Option) -> Self { + let cfg = config.map(|c| c.inner).unwrap_or_default(); + Self { + inner: SliceScheduler::new(cfg), + } + } + + /// Submit a kernel for slicing + /// + /// Args: + /// kernel_handle: CUfunction handle as int + /// total_items: Total work items to process + /// block: Block dimensions (x, y, z) + /// shared_mem: Shared memory per block + /// + /// Returns: + /// Number of slices created + fn submit( + &mut self, + kernel_handle: u64, + total_items: usize, + block: (u32, u32, u32), + shared_mem: u32, + ) -> usize { + self.inner.submit(kernel_handle, total_items, block, shared_mem) + } + + /// Submit a kernel for a specific task + /// + /// Args: + /// task_id: Associated task ID + /// kernel_handle: CUfunction handle as int + /// total_items: Total work items to process + /// block: Block dimensions (x, y, z) + /// shared_mem: Shared memory per block + /// priority: Priority (higher = more important) + /// + /// Returns: + /// Number of slices created + fn submit_for_task( + &mut self, + task_id: String, + kernel_handle: u64, + total_items: usize, + block: (u32, u32, u32), + shared_mem: u32, + priority: i32, + ) -> usize { + self.inner.submit_for_task(task_id, kernel_handle, total_items, block, shared_mem, priority) + } + + /// Get next slice to execute (round-robin fair scheduling) + fn get_next_slice(&mut self) -> Option { + self.inner.get_next_slice().map(|s| PySliceInfo { inner: s }) + } + + /// Complete the current slice + fn complete_slice(&mut self, exec_time_ms: f64) { + self.inner.complete_slice(exec_time_ms); + } + + /// Get number of pending slices + fn pending_slices(&self) -> usize { + self.inner.pending_slices() + } + + /// Get number of pending kernels + fn pending_kernels(&self) -> usize { + self.inner.pending_kernels() + } + + /// Get statistics + fn stats(&self) -> PySliceStats { + PySliceStats { + inner: self.inner.stats(), + } + } + + /// Clear all state + fn clear(&mut self) { + self.inner.clear(); + } + + /// Get configuration + fn config(&self) -> PySliceConfig { + PySliceConfig { + inner: self.inner.config().clone(), + } + } + + fn __repr__(&self) -> String { + let stats = self.inner.stats(); + format!( + "SliceScheduler(pending_slices={}, pending_kernels={})", + stats.pending_slices, stats.pending_kernels + ) + } +} + +// 
============================================================================= +// Kernel Cache Types +// ============================================================================= + +/// Compile options for Python +#[pyclass(name = "CompileOptions")] +#[derive(Clone)] +pub struct PyCompileOptions { + inner: CompileOptions, +} + +#[pymethods] +impl PyCompileOptions { + #[new] + #[pyo3(signature = (compute_capability="sm_75"))] + fn new(compute_capability: &str) -> Self { + Self { + inner: CompileOptions::with_compute(compute_capability), + } + } + + /// Add a compiler flag + fn flag(&self, flag: &str) -> Self { + Self { + inner: self.inner.clone().flag(flag), + } + } + + /// Add a define macro + fn define(&self, name: &str, value: &str) -> Self { + Self { + inner: self.inner.clone().define(name, value), + } + } + + /// Add an include path + fn include(&self, path: &str) -> Self { + Self { + inner: self.inner.clone().include(path), + } + } + + #[getter] + fn compute_capability(&self) -> &str { + &self.inner.compute_capability + } + + #[getter] + fn flags(&self) -> Vec { + self.inner.flags.clone() + } + + #[getter] + fn defines(&self) -> Vec<(String, String)> { + self.inner.defines.clone() + } + + fn __repr__(&self) -> String { + format!( + "CompileOptions(compute='{}', flags={:?})", + self.inner.compute_capability, self.inner.flags + ) + } +} + +/// Cache configuration for Python +#[pyclass(name = "CacheConfig")] +#[derive(Clone)] +pub struct PyCacheConfig { + inner: CacheConfig, +} + +#[pymethods] +impl PyCacheConfig { + #[new] + #[pyo3(signature = (max_entries=1024, max_ptx_size=268435456, enable_eviction=true, ttl_seconds=0.0))] + fn new(max_entries: usize, max_ptx_size: usize, enable_eviction: bool, ttl_seconds: f64) -> Self { + Self { + inner: CacheConfig { + max_entries, + max_ptx_size, + enable_eviction, + ttl_seconds, + }, + } + } + + #[getter] + fn max_entries(&self) -> usize { + self.inner.max_entries + } + + #[getter] + fn max_ptx_size(&self) -> usize { + self.inner.max_ptx_size + } + + #[getter] + fn enable_eviction(&self) -> bool { + self.inner.enable_eviction + } + + #[getter] + fn ttl_seconds(&self) -> f64 { + self.inner.ttl_seconds + } + + fn __repr__(&self) -> String { + format!( + "CacheConfig(max_entries={}, max_ptx_size={}, eviction={})", + self.inner.max_entries, self.inner.max_ptx_size, self.inner.enable_eviction + ) + } +} + +/// Cached kernel entry for Python +#[pyclass(name = "CachedKernel")] +#[derive(Clone)] +pub struct PyCachedKernel { + inner: CachedKernel, +} + +#[pymethods] +impl PyCachedKernel { + #[getter] + fn key(&self) -> u64 { + self.inner.key + } + + #[getter] + fn name(&self) -> &str { + &self.inner.name + } + + #[getter] + fn ptx(&self) -> &str { + &self.inner.ptx + } + + #[getter] + fn module_handle(&self) -> Option { + self.inner.module_handle + } + + #[getter] + fn function_handle(&self) -> Option { + self.inner.function_handle + } + + #[getter] + fn created_at(&self) -> f64 { + self.inner.created_at + } + + #[getter] + fn last_access(&self) -> f64 { + self.inner.last_access + } + + #[getter] + fn access_count(&self) -> usize { + self.inner.access_count + } + + /// Check if kernel is loaded (has function handle) + fn is_loaded(&self) -> bool { + self.inner.is_loaded() + } + + fn __repr__(&self) -> String { + format!( + "CachedKernel(name='{}', loaded={}, accesses={})", + self.inner.name, self.inner.is_loaded(), self.inner.access_count + ) + } +} + +/// Cache statistics for Python +#[pyclass(name = "CacheStats")] +#[derive(Clone)] +pub struct 
PyCacheStats { + inner: CacheStats, +} + +#[pymethods] +impl PyCacheStats { + #[getter] + fn hits(&self) -> usize { + self.inner.hits + } + + #[getter] + fn misses(&self) -> usize { + self.inner.misses + } + + #[getter] + fn entries(&self) -> usize { + self.inner.entries + } + + #[getter] + fn ptx_size(&self) -> usize { + self.inner.ptx_size + } + + #[getter] + fn evictions(&self) -> usize { + self.inner.evictions + } + + #[getter] + fn ttl_evictions(&self) -> usize { + self.inner.ttl_evictions + } + + #[getter] + fn loaded_count(&self) -> usize { + self.inner.loaded_count + } + + /// Calculate hit rate (0.0 - 1.0) + fn hit_rate(&self) -> f64 { + self.inner.hit_rate() + } + + fn __repr__(&self) -> String { + format!( + "CacheStats(entries={}, hit_rate={:.1}%, ptx_size={})", + self.inner.entries, + self.inner.hit_rate() * 100.0, + self.inner.ptx_size + ) + } +} + +/// Kernel cache for Python +/// +/// Caches compiled CUDA kernels (PTX) to avoid repeated +/// NVRTC compilation. +#[pyclass(name = "KernelCache")] +pub struct PyKernelCache { + inner: KernelCache, +} + +#[pymethods] +impl PyKernelCache { + #[new] + #[pyo3(signature = (config=None))] + fn new(config: Option) -> Self { + let cfg = config.map(|c| c.inner).unwrap_or_default(); + Self { + inner: KernelCache::new(cfg), + } + } + + /// Compute cache key from source and options + #[staticmethod] + fn compute_key(source: &str, name: &str, options: &PyCompileOptions) -> u64 { + KernelCache::compute_key(source, name, &options.inner) + } + + /// Compute source hash + #[staticmethod] + fn hash_source(source: &str) -> u64 { + KernelCache::hash_source(source) + } + + /// Get cached kernel by key + fn get(&mut self, key: u64) -> Option { + self.inner.get(key).map(|k| PyCachedKernel { inner: k.clone() }) + } + + /// Get cached kernel by name and options + fn get_by_name(&mut self, name: &str, options: &PyCompileOptions) -> Option { + self.inner.get_by_name(name, &options.inner).map(|k| PyCachedKernel { inner: k.clone() }) + } + + /// Insert a compiled kernel + fn insert(&mut self, source: &str, name: &str, ptx: &str, options: PyCompileOptions) -> u64 { + self.inner.insert(source, name, ptx.into(), options.inner) + } + + /// Set module and function handles for a cached kernel + fn set_handles(&mut self, key: u64, module: u64, function: u64) -> bool { + self.inner.set_handles(key, module, function) + } + + /// Remove a kernel from cache + fn remove(&mut self, key: u64) -> Option { + self.inner.remove(key).map(|k| PyCachedKernel { inner: k }) + } + + /// Check if kernel is cached + fn contains(&self, key: u64) -> bool { + self.inner.contains(key) + } + + /// Get all cached kernel names + fn kernel_names(&self) -> Vec { + self.inner.kernel_names().into_iter().map(|s| s.to_string()).collect() + } + + /// Clear expired entries (TTL) + fn clear_expired(&mut self) -> usize { + self.inner.clear_expired() + } + + /// Get statistics + fn stats(&self) -> PyCacheStats { + PyCacheStats { + inner: self.inner.stats(), + } + } + + /// Get number of entries + fn __len__(&self) -> usize { + self.inner.len() + } + + /// Check if empty + fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + /// Clear all cache + fn clear(&mut self) { + self.inner.clear(); + } + + fn __repr__(&self) -> String { + let stats = self.inner.stats(); + format!( + "KernelCache(entries={}, hit_rate={:.1}%)", + stats.entries, stats.hit_rate() * 100.0 + ) + } +} + /// Register dispatch module pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; @@ -413,5 +1353,23 
@@ pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + // Pacing + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + // Slicing + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + // Cache + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/rust/pygpukit-python/src/lib.rs b/rust/pygpukit-python/src/lib.rs index 58e3e2e..cfb225a 100644 --- a/rust/pygpukit-python/src/lib.rs +++ b/rust/pygpukit-python/src/lib.rs @@ -46,6 +46,47 @@ fn _pygpukit_rust(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + // Pacing + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + // Admission control + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + // QoS policy + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + // Slicing + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + // Pinned memory + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + // Kernel cache + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + // Partitioning + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/rust/pygpukit-python/src/scheduler.rs b/rust/pygpukit-python/src/scheduler.rs index 1431433..38e8b6f 100644 --- a/rust/pygpukit-python/src/scheduler.rs +++ b/rust/pygpukit-python/src/scheduler.rs @@ -4,6 +4,10 @@ use pyo3::prelude::*; use std::sync::Arc; use pygpukit_core::scheduler::{ Scheduler, SchedulerStats, TaskMeta, TaskState, TaskPolicy, TaskStats, + AdmissionDecision, AdmissionStats, RejectReason, + QosClass, QosPolicy, QosTaskMeta, QosEvaluation, QosPolicyEvaluator, QosStats, + ResourceRequirements, + PartitionManager, PartitionConfig, Partition, PartitionLimits, PartitionUsage, PartitionStats, }; /// Task state enum for Python @@ -317,6 +321,672 @@ impl PyTaskStats { } } +/// Rejection reason enum for Python +#[pyclass(name = "RejectReason")] +#[derive(Clone)] +pub struct PyRejectReason { + inner: RejectReason, +} + +#[pymethods] +impl PyRejectReason { + /// Get rejection type as string + #[getter] + fn reason_type(&self) -> String { + match &self.inner { + RejectReason::InsufficientMemory { .. } => "InsufficientMemory".into(), + RejectReason::BandwidthExceeded { .. } => "BandwidthExceeded".into(), + RejectReason::QueueFull { .. } => "QueueFull".into(), + RejectReason::UnsatisfiableDependencies { .. } => "UnsatisfiableDependencies".into(), + RejectReason::Custom(_) => "Custom".into(), + } + } + + /// Get requested memory (if InsufficientMemory) + #[getter] + fn requested_memory(&self) -> Option { + match &self.inner { + RejectReason::InsufficientMemory { requested, .. } => Some(*requested), + _ => None, + } + } + + /// Get available memory (if InsufficientMemory) + #[getter] + fn available_memory(&self) -> Option { + match &self.inner { + RejectReason::InsufficientMemory { available, .. 
} => Some(*available), + _ => None, + } + } + + /// Get message + fn message(&self) -> String { + match &self.inner { + RejectReason::InsufficientMemory { requested, available } => { + format!("Insufficient memory: requested {} bytes, available {} bytes", requested, available) + } + RejectReason::BandwidthExceeded { requested_bw, available_bw } => { + format!("Bandwidth exceeded: requested {:.2}, available {:.2}", requested_bw, available_bw) + } + RejectReason::QueueFull { current, max } => { + format!("Queue full: {} tasks pending, max {}", current, max) + } + RejectReason::UnsatisfiableDependencies { missing } => { + format!("Unsatisfiable dependencies: {:?}", missing) + } + RejectReason::Custom(msg) => msg.clone(), + } + } + + fn __repr__(&self) -> String { + format!("RejectReason({})", self.message()) + } +} + +/// Admission decision for Python +#[pyclass(name = "AdmissionDecision")] +#[derive(Clone)] +pub struct PyAdmissionDecision { + inner: AdmissionDecision, +} + +#[pymethods] +impl PyAdmissionDecision { + /// Check if admitted + fn is_admitted(&self) -> bool { + self.inner.is_admitted() + } + + /// Check if rejected + fn is_rejected(&self) -> bool { + self.inner.is_rejected() + } + + /// Check if queued + fn is_queued(&self) -> bool { + self.inner.is_queued() + } + + /// Get decision type as string + #[getter] + fn decision_type(&self) -> String { + match &self.inner { + AdmissionDecision::Admit { .. } => "Admit".into(), + AdmissionDecision::Reject { .. } => "Reject".into(), + AdmissionDecision::Queue { .. } => "Queue".into(), + } + } + + /// Get reserved memory (if admitted) + #[getter] + fn reserved_memory(&self) -> Option { + match &self.inner { + AdmissionDecision::Admit { reserved_memory, .. } => Some(*reserved_memory), + _ => None, + } + } + + /// Get reserved bandwidth (if admitted) + #[getter] + fn reserved_bandwidth(&self) -> Option { + match &self.inner { + AdmissionDecision::Admit { reserved_bandwidth, .. } => Some(*reserved_bandwidth), + _ => None, + } + } + + /// Get queue position (if queued) + #[getter] + fn queue_position(&self) -> Option { + match &self.inner { + AdmissionDecision::Queue { position, .. } => Some(*position), + _ => None, + } + } + + /// Get estimated wait time in ms (if queued) + #[getter] + fn estimated_wait_ms(&self) -> Option { + match &self.inner { + AdmissionDecision::Queue { estimated_wait_ms, .. 
} => Some(*estimated_wait_ms), + _ => None, + } + } + + /// Get rejection reason (if rejected) + #[getter] + fn rejection_reason(&self) -> Option { + match &self.inner { + AdmissionDecision::Reject { reason } => Some(PyRejectReason { inner: reason.clone() }), + _ => None, + } + } + + fn __repr__(&self) -> String { + match &self.inner { + AdmissionDecision::Admit { reserved_memory, reserved_bandwidth } => { + format!("AdmissionDecision(Admit, memory={}, bandwidth={:.4})", reserved_memory, reserved_bandwidth) + } + AdmissionDecision::Reject { reason } => { + format!("AdmissionDecision(Reject, reason={})", PyRejectReason { inner: reason.clone() }.message()) + } + AdmissionDecision::Queue { position, estimated_wait_ms } => { + format!("AdmissionDecision(Queue, position={}, wait={:.1}ms)", position, estimated_wait_ms) + } + } + } +} + +/// Admission statistics for Python +#[pyclass(name = "AdmissionStats")] +#[derive(Clone)] +pub struct PyAdmissionStats { + inner: AdmissionStats, +} + +#[pymethods] +impl PyAdmissionStats { + /// Total admitted tasks + #[getter] + fn admitted_count(&self) -> usize { + self.inner.admitted_count + } + + /// Total rejected tasks + #[getter] + fn rejected_count(&self) -> usize { + self.inner.rejected_count + } + + /// Total queued tasks (best-effort) + #[getter] + fn queued_count(&self) -> usize { + self.inner.queued_count + } + + /// Current pending tasks + #[getter] + fn pending_count(&self) -> usize { + self.inner.pending_count + } + + /// Currently reserved memory + #[getter] + fn reserved_memory(&self) -> usize { + self.inner.reserved_memory + } + + /// Currently reserved bandwidth + #[getter] + fn reserved_bandwidth(&self) -> f64 { + self.inner.reserved_bandwidth + } + + /// Available memory + #[getter] + fn available_memory(&self) -> usize { + self.inner.available_memory + } + + /// Available bandwidth + #[getter] + fn available_bandwidth(&self) -> f64 { + self.inner.available_bandwidth + } + + fn __repr__(&self) -> String { + format!( + "AdmissionStats(admitted={}, rejected={}, queued={}, pending={})", + self.inner.admitted_count, self.inner.rejected_count, + self.inner.queued_count, self.inner.pending_count + ) + } +} + +// ============================================================================= +// QoS Policy Types +// ============================================================================= + +/// QoS class enum for Python +#[pyclass(name = "QosClass", eq, eq_int)] +#[derive(Clone, Copy, PartialEq, Eq)] +pub enum PyQosClass { + Guaranteed = 0, + Burstable = 1, + BestEffort = 2, +} + +impl From for PyQosClass { + fn from(class: QosClass) -> Self { + match class { + QosClass::Guaranteed => PyQosClass::Guaranteed, + QosClass::Burstable => PyQosClass::Burstable, + QosClass::BestEffort => PyQosClass::BestEffort, + } + } +} + +impl From for QosClass { + fn from(class: PyQosClass) -> Self { + match class { + PyQosClass::Guaranteed => QosClass::Guaranteed, + PyQosClass::Burstable => QosClass::Burstable, + PyQosClass::BestEffort => QosClass::BestEffort, + } + } +} + +#[pymethods] +impl PyQosClass { + /// Get scheduling priority for this QoS class + fn priority(&self) -> i32 { + QosClass::from(*self).priority() + } + + /// Check if this class can preempt another + fn can_preempt(&self, other: PyQosClass) -> bool { + QosClass::from(*self).can_preempt(&QosClass::from(other)) + } + + /// Get memory overcommit ratio + fn memory_overcommit_ratio(&self) -> f64 { + QosClass::from(*self).memory_overcommit_ratio() + } + + /// Get bandwidth allocation ratio + fn 
bandwidth_ratio(&self) -> f64 { + QosClass::from(*self).bandwidth_ratio() + } +} + +/// Resource requirements for Python +#[pyclass(name = "ResourceRequirements")] +#[derive(Clone)] +pub struct PyResourceRequirements { + inner: ResourceRequirements, +} + +#[pymethods] +impl PyResourceRequirements { + /// Create new resource requirements + #[new] + #[pyo3(signature = (memory_request, memory_limit=None, bandwidth_request=0.0, bandwidth_limit=1.0))] + fn new( + memory_request: usize, + memory_limit: Option, + bandwidth_request: f64, + bandwidth_limit: f64, + ) -> Self { + Self { + inner: ResourceRequirements { + memory_request, + memory_limit: memory_limit.unwrap_or(memory_request), + bandwidth_request, + bandwidth_limit, + }, + } + } + + /// Create guaranteed requirements + #[staticmethod] + fn guaranteed(memory: usize) -> Self { + Self { + inner: ResourceRequirements::guaranteed(memory), + } + } + + /// Create burstable requirements + #[staticmethod] + fn burstable(memory_request: usize, burst_ratio: f64) -> Self { + Self { + inner: ResourceRequirements::burstable(memory_request, burst_ratio), + } + } + + /// Create best-effort requirements + #[staticmethod] + fn best_effort() -> Self { + Self { + inner: ResourceRequirements::best_effort(), + } + } + + #[getter] + fn memory_request(&self) -> usize { + self.inner.memory_request + } + + #[getter] + fn memory_limit(&self) -> usize { + self.inner.memory_limit + } + + #[getter] + fn bandwidth_request(&self) -> f64 { + self.inner.bandwidth_request + } + + #[getter] + fn bandwidth_limit(&self) -> f64 { + self.inner.bandwidth_limit + } + + fn __repr__(&self) -> String { + format!( + "ResourceRequirements(memory={}/{}, bandwidth={:.2}/{:.2})", + self.inner.memory_request, self.inner.memory_limit, + self.inner.bandwidth_request, self.inner.bandwidth_limit + ) + } +} + +/// QoS policy for Python +#[pyclass(name = "QosPolicy")] +#[derive(Clone)] +pub struct PyQosPolicy { + inner: QosPolicy, +} + +#[pymethods] +impl PyQosPolicy { + /// Create a Guaranteed policy + #[staticmethod] + fn guaranteed(memory: usize) -> Self { + Self { + inner: QosPolicy::guaranteed(memory), + } + } + + /// Create a Burstable policy + #[staticmethod] + fn burstable(memory_request: usize, burst_ratio: f64) -> Self { + Self { + inner: QosPolicy::burstable(memory_request, burst_ratio), + } + } + + /// Create a BestEffort policy + #[staticmethod] + fn best_effort() -> Self { + Self { + inner: QosPolicy::best_effort(), + } + } + + #[getter] + fn qos_class(&self) -> PyQosClass { + self.inner.class.into() + } + + /// Get memory to reserve + fn memory_to_reserve(&self) -> usize { + self.inner.memory_to_reserve() + } + + /// Get bandwidth to reserve + fn bandwidth_to_reserve(&self) -> f64 { + self.inner.bandwidth_to_reserve() + } + + fn __repr__(&self) -> String { + format!("QosPolicy({:?})", self.inner.class) + } +} + +/// QoS task metadata for Python +#[pyclass(name = "QosTaskMeta")] +#[derive(Clone)] +pub struct PyQosTaskMeta { + inner: QosTaskMeta, +} + +#[pymethods] +impl PyQosTaskMeta { + /// Create a Guaranteed task + #[staticmethod] + fn guaranteed(id: String, name: String, memory: usize) -> Self { + Self { + inner: QosTaskMeta::guaranteed(id, name, memory), + } + } + + /// Create a Burstable task + #[staticmethod] + fn burstable(id: String, name: String, memory_request: usize, burst_ratio: f64) -> Self { + Self { + inner: QosTaskMeta::burstable(id, name, memory_request, burst_ratio), + } + } + + /// Create a BestEffort task + #[staticmethod] + fn best_effort(id: String, name: 
String) -> Self { + Self { + inner: QosTaskMeta::best_effort(id, name), + } + } + + #[getter] + fn id(&self) -> String { + self.inner.task.id.clone() + } + + #[getter] + fn name(&self) -> String { + self.inner.task.name.clone() + } + + #[getter] + fn qos_class(&self) -> PyQosClass { + self.inner.qos.class.into() + } + + /// Get effective priority + fn effective_priority(&self) -> i32 { + self.inner.effective_priority() + } + + fn __repr__(&self) -> String { + format!( + "QosTaskMeta(id='{}', class={:?}, priority={})", + self.inner.task.id, self.inner.qos.class, self.inner.effective_priority() + ) + } +} + +/// QoS evaluation result for Python +#[pyclass(name = "QosEvaluation")] +#[derive(Clone)] +pub struct PyQosEvaluation { + inner: QosEvaluation, +} + +#[pymethods] +impl PyQosEvaluation { + fn is_admitted(&self) -> bool { + self.inner.is_admitted() + } + + fn is_throttled(&self) -> bool { + self.inner.is_throttled() + } + + fn is_queued(&self) -> bool { + self.inner.is_queued() + } + + fn is_rejected(&self) -> bool { + self.inner.is_rejected() + } + + #[getter] + fn decision_type(&self) -> String { + match &self.inner { + QosEvaluation::Admit { .. } => "Admit".into(), + QosEvaluation::Throttle { .. } => "Throttle".into(), + QosEvaluation::Queue { .. } => "Queue".into(), + QosEvaluation::Reject { .. } => "Reject".into(), + } + } + + #[getter] + fn qos_class(&self) -> Option { + match &self.inner { + QosEvaluation::Admit { class, .. } => Some((*class).into()), + QosEvaluation::Throttle { class, .. } => Some((*class).into()), + _ => None, + } + } + + #[getter] + fn reserved_memory(&self) -> Option { + match &self.inner { + QosEvaluation::Admit { reserved_memory, .. } => Some(*reserved_memory), + _ => None, + } + } + + #[getter] + fn reject_reason(&self) -> Option { + match &self.inner { + QosEvaluation::Reject { reason } => Some(reason.clone()), + _ => None, + } + } + + fn __repr__(&self) -> String { + match &self.inner { + QosEvaluation::Admit { class, reserved_memory, reserved_bandwidth } => { + format!("QosEvaluation(Admit, class={:?}, memory={}, bw={:.2})", + class, reserved_memory, reserved_bandwidth) + } + QosEvaluation::Throttle { class, allowed_memory, allowed_bandwidth } => { + format!("QosEvaluation(Throttle, class={:?}, allowed_mem={}, allowed_bw={:.2})", + class, allowed_memory, allowed_bandwidth) + } + QosEvaluation::Queue { position } => { + format!("QosEvaluation(Queue, position={})", position) + } + QosEvaluation::Reject { reason } => { + format!("QosEvaluation(Reject, reason='{}')", reason) + } + } + } +} + +/// QoS statistics for Python +#[pyclass(name = "QosStats")] +#[derive(Clone)] +pub struct PyQosStats { + inner: QosStats, +} + +#[pymethods] +impl PyQosStats { + #[getter] + fn total_memory(&self) -> usize { + self.inner.total_memory + } + + #[getter] + fn total_bandwidth(&self) -> f64 { + self.inner.total_bandwidth + } + + #[getter] + fn guaranteed_memory(&self) -> usize { + self.inner.guaranteed_memory + } + + #[getter] + fn guaranteed_bandwidth(&self) -> f64 { + self.inner.guaranteed_bandwidth + } + + #[getter] + fn burstable_memory(&self) -> usize { + self.inner.burstable_memory + } + + #[getter] + fn best_effort_queue(&self) -> usize { + self.inner.best_effort_queue + } + + #[getter] + fn available_memory(&self) -> usize { + self.inner.available_memory + } + + #[getter] + fn available_bandwidth(&self) -> f64 { + self.inner.available_bandwidth + } + + fn __repr__(&self) -> String { + format!( + "QosStats(guaranteed_mem={}, burstable_mem={}, best_effort_queue={})", 
+ self.inner.guaranteed_memory, self.inner.burstable_memory, + self.inner.best_effort_queue + ) + } +} + +/// QoS policy evaluator for Python +#[pyclass(name = "QosPolicyEvaluator")] +pub struct PyQosPolicyEvaluator { + inner: QosPolicyEvaluator, +} + +#[pymethods] +impl PyQosPolicyEvaluator { + #[new] + #[pyo3(signature = (total_memory, total_bandwidth=1.0))] + fn new(total_memory: usize, total_bandwidth: f64) -> Self { + Self { + inner: QosPolicyEvaluator::new(total_memory, total_bandwidth), + } + } + + /// Evaluate QoS policy for a task + fn evaluate(&self, task: &PyQosTaskMeta) -> PyQosEvaluation { + PyQosEvaluation { + inner: self.inner.evaluate(&task.inner), + } + } + + /// Reserve resources for an admitted task + fn reserve(&mut self, evaluation: &PyQosEvaluation) { + self.inner.reserve(&evaluation.inner); + } + + /// Release resources when a task completes + fn release(&mut self, qos_class: PyQosClass, memory: usize, bandwidth: f64) { + self.inner.release(qos_class.into(), memory, bandwidth); + } + + /// Get statistics + fn stats(&self) -> PyQosStats { + PyQosStats { + inner: self.inner.stats(), + } + } + + /// Reset evaluator state + fn reset(&mut self) { + self.inner.reset(); + } + + fn __repr__(&self) -> String { + let stats = self.inner.stats(); + format!( + "QosPolicyEvaluator(avail_mem={}, avail_bw={:.2})", + stats.available_memory, stats.available_bandwidth + ) + } +} + /// Thread-safe task scheduler with bandwidth pacing. /// /// Args: @@ -350,6 +1020,30 @@ impl PyScheduler { self.inner.submit(task.inner) } + /// Admit a task through admission control. + /// + /// Returns an AdmissionDecision indicating whether the task + /// was admitted, queued, or rejected. + fn admit(&self, task: PyTaskMeta) -> PyAdmissionDecision { + PyAdmissionDecision { + inner: self.inner.admit(task.inner), + } + } + + /// Evaluate admission for a task without submitting. + fn evaluate_admission(&self, task: &PyTaskMeta) -> PyAdmissionDecision { + PyAdmissionDecision { + inner: self.inner.evaluate_admission(&task.inner), + } + } + + /// Get admission control statistics. + fn admission_stats(&self) -> PyAdmissionStats { + PyAdmissionStats { + inner: self.inner.admission_stats(), + } + } + /// Get tasks that are ready to run. 
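+    // Typical admission flow on the core scheduler (illustrative sketch;
+    // `scheduler` is an already-constructed Scheduler instance):
+    //
+    //     let decision = scheduler.admit(task);
+    //     if decision.is_admitted() {
+    //         // proceed; the decision also reports the reserved memory/bandwidth
+    //     } else if decision.is_queued() {
+    //         // best-effort task parked until resources free up
+    //     } else {
+    //         // rejected: the RejectReason names the cause (memory, bandwidth, queue full, ...)
+    //     }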
/// Get tasks that are ready to run. #[pyo3(signature = (max_tasks=1))] fn get_runnable_tasks(&self, max_tasks: usize) -> Vec<String> { @@ -435,6 +1129,409 @@ impl PyScheduler { } } +// ============================================================================= +// GPU Partitioning Types +// ============================================================================= + +/// Partition limits for Python +#[pyclass(name = "PartitionLimits")] +#[derive(Clone)] +pub struct PyPartitionLimits { + inner: PartitionLimits, +} + +#[pymethods] +impl PyPartitionLimits { + #[new] + #[pyo3(signature = (memory_quota=0, compute_share=1.0, bandwidth_share=1.0, max_streams=16))] + fn new(memory_quota: usize, compute_share: f64, bandwidth_share: f64, max_streams: usize) -> Self { + Self { + inner: PartitionLimits { + memory_quota, + compute_share, + bandwidth_share, + max_streams, + ..Default::default() + }, + } + } + + /// Create with memory quota + #[staticmethod] + fn with_memory(memory_quota: usize) -> Self { + Self { + inner: PartitionLimits::with_memory(memory_quota), + } + } + + /// Create with compute share + #[staticmethod] + fn with_compute(compute_share: f64) -> Self { + Self { + inner: PartitionLimits::with_compute(compute_share), + } + } + + /// Set memory quota + fn memory(&self, quota: usize) -> Self { + Self { + inner: self.inner.clone().memory(quota), + } + } + + /// Set compute share + fn compute(&self, share: f64) -> Self { + Self { + inner: self.inner.clone().compute(share), + } + } + + /// Set bandwidth share + fn bandwidth(&self, share: f64) -> Self { + Self { + inner: self.inner.clone().bandwidth(share), + } + } + + #[getter] + fn memory_quota(&self) -> usize { + self.inner.memory_quota + } + + #[getter] + fn compute_share(&self) -> f64 { + self.inner.compute_share + } + + #[getter] + fn bandwidth_share(&self) -> f64 { + self.inner.bandwidth_share + } + + #[getter] + fn max_streams(&self) -> usize { + self.inner.max_streams + } + + fn __repr__(&self) -> String { + format!( + "PartitionLimits(memory={}, compute={:.1}%, bandwidth={:.1}%)", + self.inner.memory_quota, + self.inner.compute_share * 100.0, + self.inner.bandwidth_share * 100.0 + ) + } +} + +/// Partition usage for Python +#[pyclass(name = "PartitionUsage")] +#[derive(Clone)] +pub struct PyPartitionUsage { + inner: PartitionUsage, +} + +#[pymethods] +impl PyPartitionUsage { + #[getter] + fn memory_used(&self) -> usize { + self.inner.memory_used + } + + #[getter] + fn active_streams(&self) -> usize { + self.inner.active_streams + } + + #[getter] + fn pending_kernels(&self) -> usize { + self.inner.pending_kernels + } + + #[getter] + fn pending_transfers(&self) -> usize { + self.inner.pending_transfers + } + + #[getter] + fn total_kernels(&self) -> usize { + self.inner.total_kernels + } + + #[getter] + fn total_transfers(&self) -> usize { + self.inner.total_transfers + } + + #[getter] + fn compute_time_ms(&self) -> f64 { + self.inner.compute_time_ms + } + + fn __repr__(&self) -> String { + format!( + "PartitionUsage(memory={}, streams={}, kernels={})", + self.inner.memory_used, self.inner.active_streams, self.inner.total_kernels + ) + } +} + +/// Partition for Python +#[pyclass(name = "Partition")] +#[derive(Clone)] +pub struct PyPartition { + inner: Partition, +} + +#[pymethods] +impl PyPartition { + #[getter] + fn id(&self) -> &str { + &self.inner.id + } + + #[getter] + fn name(&self) -> &str { + &self.inner.name + } + + #[getter] + fn limits(&self) -> PyPartitionLimits { + PyPartitionLimits { inner: self.inner.limits.clone() } + } + + #[getter] + fn usage(&self) -> 
PyPartitionUsage { + PyPartitionUsage { inner: self.inner.usage.clone() } + } + + #[getter] + fn tasks(&self) -> Vec<String> { + self.inner.tasks.clone() + } + + #[getter] + fn enabled(&self) -> bool { + self.inner.enabled + } + + /// Get memory utilization (0.0 - 1.0) + fn memory_utilization(&self) -> f64 { + self.inner.memory_utilization() + } + + fn __repr__(&self) -> String { + format!( + "Partition(id='{}', name='{}', enabled={})", + self.inner.id, self.inner.name, self.inner.enabled + ) + } +} + +/// Partition config for Python +#[pyclass(name = "PartitionConfig")] +#[derive(Clone)] +pub struct PyPartitionConfig { + inner: PartitionConfig, +} + +#[pymethods] +impl PyPartitionConfig { + #[new] + #[pyo3(signature = (total_memory=8589934592, allow_overcommit=false, overcommit_ratio=1.0))] + fn new(total_memory: usize, allow_overcommit: bool, overcommit_ratio: f64) -> Self { + Self { + inner: PartitionConfig { + total_memory, + allow_overcommit, + overcommit_ratio, + ..Default::default() + }, + } + } + + #[getter] + fn total_memory(&self) -> usize { + self.inner.total_memory + } + + #[getter] + fn allow_overcommit(&self) -> bool { + self.inner.allow_overcommit + } + + #[getter] + fn overcommit_ratio(&self) -> f64 { + self.inner.overcommit_ratio + } + + fn __repr__(&self) -> String { + format!( + "PartitionConfig(memory={}, overcommit={})", + self.inner.total_memory, self.inner.allow_overcommit + ) + } +} + +/// Partition statistics for Python +#[pyclass(name = "PartitionStats")] +#[derive(Clone)] +pub struct PyPartitionStats { + inner: PartitionStats, +} + +#[pymethods] +impl PyPartitionStats { + #[getter] + fn partition_count(&self) -> usize { + self.inner.partition_count + } + + #[getter] + fn active_partitions(&self) -> usize { + self.inner.active_partitions + } + + #[getter] + fn total_memory_allocated(&self) -> usize { + self.inner.total_memory_allocated + } + + #[getter] + fn total_compute_allocated(&self) -> f64 { + self.inner.total_compute_allocated + } + + #[getter] + fn total_bandwidth_allocated(&self) -> f64 { + self.inner.total_bandwidth_allocated + } + + #[getter] + fn available_memory(&self) -> usize { + self.inner.available_memory + } + + #[getter] + fn available_compute(&self) -> f64 { + self.inner.available_compute + } + + #[getter] + fn available_bandwidth(&self) -> f64 { + self.inner.available_bandwidth + } + + fn __repr__(&self) -> String { + format!( + "PartitionStats(partitions={}, memory_alloc={}, compute_alloc={:.1}%)", + self.inner.partition_count, + self.inner.total_memory_allocated, + self.inner.total_compute_allocated * 100.0 + ) + } +} + +/// Partition manager for Python +/// +/// Manages GPU resource partitions for multi-tenant or multi-task isolation.
+#[pyclass(name = "PartitionManager")] +pub struct PyPartitionManager { + inner: PartitionManager, +} + +#[pymethods] +impl PyPartitionManager { + #[new] + #[pyo3(signature = (config=None))] + fn new(config: Option<PyPartitionConfig>) -> Self { + let cfg = config.map(|c| c.inner).unwrap_or_default(); + Self { + inner: PartitionManager::new(cfg), + } + } + + /// Create with total memory + #[staticmethod] + fn with_memory(total_memory: usize) -> Self { + Self { + inner: PartitionManager::with_memory(total_memory), + } + } + + /// Create a new partition + fn create_partition(&mut self, id: &str, name: &str, limits: PyPartitionLimits) -> PyResult<()> { + self.inner.create_partition(id, name, limits.inner).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(e.to_string()) + }) + } + + /// Delete a partition + fn delete_partition(&mut self, id: &str) -> PyResult<PyPartition> { + self.inner.delete_partition(id) + .map(|p| PyPartition { inner: p }) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string())) + } + + /// Get a partition + fn get(&self, id: &str) -> Option<PyPartition> { + self.inner.get(id).map(|p| PyPartition { inner: p.clone() }) + } + + /// Assign a task to a partition + fn assign_task(&mut self, task_id: &str, partition_id: &str) -> PyResult<()> { + self.inner.assign_task(task_id, partition_id).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(e.to_string()) + }) + } + + /// Get partition for a task + fn get_task_partition(&self, task_id: &str) -> Option<PyPartition> { + self.inner.get_task_partition(task_id).map(|p| PyPartition { inner: p.clone() }) + } + + /// Unassign a task from its partition + fn unassign_task(&mut self, task_id: &str) { + self.inner.unassign_task(task_id); + } + + /// Set default partition + fn set_default(&mut self, id: &str) -> PyResult<()> { + self.inner.set_default(id).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(e.to_string()) + }) + } + + /// Get default partition + fn default_partition(&self) -> Option<PyPartition> { + self.inner.default_partition().map(|p| PyPartition { inner: p.clone() }) + } + + /// List all partition IDs + fn partition_ids(&self) -> Vec<String> { + self.inner.partition_ids().into_iter().map(|s| s.to_string()).collect() + } + + /// Get statistics + fn stats(&self) -> PyPartitionStats { + PyPartitionStats { + inner: self.inner.stats(), + } + } + + /// Clear all partitions + fn clear(&mut self) { + self.inner.clear(); + } + + fn __repr__(&self) -> String { + let stats = self.inner.stats(); + format!( + "PartitionManager(partitions={}, memory_alloc={})", + stats.partition_count, stats.total_memory_allocated + ) + } +} +
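The partition manager is easiest to see end to end from Python. A small, hedged sketch follows; partition names, sizes, and task IDs are made up, it assumes the module is imported as `_pygpukit_rust`, and it relies on the default `PartitionConfig` semantics (no overcommit).

```python
import _pygpukit_rust as rust

# Carve an 8 GiB GPU into a guaranteed tenant and a best-effort tenant.
manager = rust.PartitionManager.with_memory(8 * 1024**3)
primary = rust.PartitionLimits.with_memory(6 * 1024**3).compute(0.75).bandwidth(0.75)
scratch = rust.PartitionLimits.with_memory(1 * 1024**3).compute(0.25)

manager.create_partition("primary", "Primary tenant", primary)
manager.create_partition("scratch", "Best-effort tenant", scratch)
manager.set_default("scratch")

# Route a task to a partition and inspect utilization.
manager.assign_task("task_0", "primary")
part = manager.get_task_partition("task_0")
print(part, f"memory util: {part.memory_utilization():.1%}")

stats = manager.stats()
print(f"{stats.partition_count} partitions, "
      f"{stats.total_memory_allocated / 1024**3:.1f} GiB reserved, "
      f"{stats.available_memory / 1024**3:.1f} GiB free")
```

Note that `create_partition`, `assign_task`, and `set_default` surface partitioning errors as `ValueError`, matching the `PyValueError` mapping in the bindings above.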
/// Register scheduler module pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; @@ -443,5 +1540,24 @@ pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; + // Admission control + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + // QoS policy + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + // Partitioning + m.add_class::<PyPartitionLimits>()?; + m.add_class::<PyPartitionUsage>()?; + m.add_class::<PyPartition>()?; + m.add_class::<PyPartitionConfig>()?; + m.add_class::<PyPartitionStats>()?; + m.add_class::<PyPartitionManager>()?; Ok(()) } diff --git a/rust/pygpukit-python/src/transfer.rs index dd78521..a6db9fb 100644 --- a/rust/pygpukit-python/src/transfer.rs +++ b/rust/pygpukit-python/src/transfer.rs @@ -3,6 +3,7 @@ use pyo3::prelude::*; use pygpukit_core::transfer::{ AsyncTransferEngine, TransferOp, TransferState, TransferStats, TransferType, StreamType, + PinnedMemoryManager, PinnedPoolConfig, PinnedBlock, PinnedStats, PinnedError, }; /// Python wrapper for TransferType enum @@ -448,6 +449,315 @@ impl PyAsyncTransferEngine { } } +// ============================================================================= +// Pinned Memory Types +// ============================================================================= + +/// Pinned memory pool configuration for Python +#[pyclass(name = "PinnedPoolConfig")] +#[derive(Clone)] +pub struct PyPinnedPoolConfig { + inner: PinnedPoolConfig, +} + +#[pymethods] +impl PyPinnedPoolConfig { + #[new] + #[pyo3(signature = (max_size=1073741824, enable_pooling=true, alignment=256))] + fn new(max_size: usize, enable_pooling: bool, alignment: usize) -> Self { + Self { + inner: PinnedPoolConfig { + max_size, + enable_pooling, + alignment, + ..Default::default() + }, + } + } + + #[getter] + fn max_size(&self) -> usize { + self.inner.max_size + } + + #[getter] + fn enable_pooling(&self) -> bool { + self.inner.enable_pooling + } + + #[getter] + fn alignment(&self) -> usize { + self.inner.alignment + } + + /// Get size class for a given size + fn get_size_class(&self, size: usize) -> usize { + self.inner.get_size_class(size) + } + + fn __repr__(&self) -> String { + format!( + "PinnedPoolConfig(max_size={}, pooling={}, alignment={})", + self.inner.max_size, self.inner.enable_pooling, self.inner.alignment + ) + } +} + +/// Pinned memory block for Python +#[pyclass(name = "PinnedBlock")] +#[derive(Clone)] +pub struct PyPinnedBlock { + inner: PinnedBlock, +} + +#[pymethods] +impl PyPinnedBlock { + #[getter] + fn id(&self) -> u64 { + self.inner.id + } + + #[getter] + fn host_ptr(&self) -> u64 { + self.inner.host_ptr + } + + #[getter] + fn size(&self) -> usize { + self.inner.size + } + + #[getter] + fn in_use(&self) -> bool { + self.inner.in_use + } + + #[getter] + fn task_id(&self) -> Option<String> { + self.inner.task_id.clone() + } + + #[getter] + fn allocated_at(&self) -> f64 { + self.inner.allocated_at + } + + #[getter] + fn last_access(&self) -> f64 { + self.inner.last_access + } + + fn __repr__(&self) -> String { + format!( + "PinnedBlock(id={}, size={}, in_use={})", + self.inner.id, self.inner.size, self.inner.in_use + ) + } +} + +/// Pinned memory statistics for Python +#[pyclass(name = "PinnedStats")] +#[derive(Clone)] +pub struct PyPinnedStats { + inner: PinnedStats, +} + +#[pymethods] +impl PyPinnedStats { + #[getter] + fn total_allocated(&self) -> usize { + self.inner.total_allocated + } + + #[getter] + fn current_used(&self) -> usize { + self.inner.current_used + } + + #[getter] + fn peak_used(&self) -> usize { + self.inner.peak_used + } + + #[getter] + fn total_allocations(&self) -> usize { + self.inner.total_allocations + } + + #[getter] + fn total_frees(&self) -> usize { + self.inner.total_frees + } + + #[getter] + fn pool_hits(&self) -> usize { + self.inner.pool_hits + } + + #[getter] + fn pool_misses(&self) -> usize { + self.inner.pool_misses + } + + #[getter] + fn pool_size(&self) -> usize { + self.inner.pool_size + } + + #[getter] + fn pooled_blocks(&self) -> usize { + self.inner.pooled_blocks + } + + /// Pool hit rate (0.0 - 1.0) + fn hit_rate(&self) -> f64 { + let total = self.inner.pool_hits + self.inner.pool_misses; + if total > 0 { + self.inner.pool_hits as f64 / total as f64 + } else { + 0.0 + } + } + + fn __repr__(&self) -> String { + format!( + "PinnedStats(used={}/{}, hit_rate={:.1}%)", + self.inner.current_used, self.inner.peak_used, + self.hit_rate() * 100.0 + ) + } +} + 
+/// Pinned memory manager for Python +/// +/// Manages page-locked (pinned) host memory for faster +/// CPU-GPU transfers with optional pooling for reuse. +#[pyclass(name = "PinnedMemoryManager")] +pub struct PyPinnedMemoryManager { + inner: PinnedMemoryManager, +} + +#[pymethods] +impl PyPinnedMemoryManager { + #[new] + #[pyo3(signature = (config=None))] + fn new(config: Option<PyPinnedPoolConfig>) -> Self { + let cfg = config.map(|c| c.inner).unwrap_or_default(); + Self { + inner: PinnedMemoryManager::new(cfg), + } + } + + /// Create with max size + #[staticmethod] + fn with_max_size(max_size: usize) -> Self { + Self { + inner: PinnedMemoryManager::with_max_size(max_size), + } + } + + /// Check if allocation would succeed + fn can_allocate(&self, size: usize) -> bool { + self.inner.can_allocate(size) + } + + /// Allocate pinned memory + /// + /// Returns (id, size_class, reused) tuple. + /// If reused=False, caller must perform cudaHostAlloc + /// and then call register(). + fn allocate(&mut self, size: usize) -> PyResult<(u64, usize, bool)> { + self.inner.allocate(size).map_err(|e| { + pyo3::exceptions::PyMemoryError::new_err(e.to_string()) + }) + } + + /// Register an allocated block + /// + /// Call this after cudaHostAlloc succeeds with the actual pointer. + fn register(&mut self, id: u64, host_ptr: u64, size: usize) { + self.inner.register(id, host_ptr, size); + } + + /// Free a pinned block + /// + /// Returns (should_free, host_ptr) tuple. + /// If should_free=True, caller should call cudaFreeHost. + fn free(&mut self, id: u64) -> PyResult<(bool, u64)> { + self.inner.free(id).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(e.to_string()) + }) + } + + /// Associate a block with a task + fn associate_task(&mut self, id: u64, task_id: String) -> PyResult<()> { + self.inner.associate_task(id, task_id).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(e.to_string()) + }) + } + + /// Get a block by ID + fn get(&self, id: u64) -> Option<PyPinnedBlock> { + self.inner.get(id).map(|b| PyPinnedBlock { inner: b.clone() }) + } + + /// Touch a block to update access time + fn touch(&mut self, id: u64) -> PyResult<()> { + self.inner.touch(id).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(e.to_string()) + }) + } + + /// Get blocks for a task + fn get_blocks_for_task(&self, task_id: &str) -> Vec<PyPinnedBlock> { + self.inner.get_blocks_for_task(task_id) + .into_iter() + .map(|b| PyPinnedBlock { inner: b.clone() }) + .collect() + } + + /// Free all blocks for a task + /// + /// Returns list of (should_free, host_ptr) tuples. + fn free_task_blocks(&mut self, task_id: &str) -> Vec<(bool, u64)> { + self.inner.free_task_blocks(task_id) + } + + /// Get statistics + fn stats(&self) -> PyPinnedStats { + PyPinnedStats { + inner: self.inner.stats(), + } + } + + /// Clear the pool (free all pooled blocks) + /// + /// Returns list of host pointers to free. + fn clear_pool(&mut self) -> Vec<u64> { + self.inner.clear_pool() + } + + /// Clear all state + /// + /// Returns list of all host pointers to free.
+ fn clear(&mut self) -> Vec<u64> { + self.inner.clear() + } + + /// Get number of active blocks + fn active_count(&self) -> usize { + self.inner.active_count() + } + + fn __repr__(&self) -> String { + let stats = self.inner.stats(); + format!( + "PinnedMemoryManager(active={}, used={} bytes)", + self.inner.active_count(), stats.current_used + ) + } +} + /// Register transfer module pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; @@ -456,5 +766,10 @@ pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + // Pinned memory + m.add_class::<PyPinnedPoolConfig>()?; + m.add_class::<PyPinnedBlock>()?; + m.add_class::<PyPinnedStats>()?; + m.add_class::<PyPinnedMemoryManager>()?; Ok(()) }
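Finally, a hedged sketch of the pinned-memory protocol the doc comments above describe: the manager only does bookkeeping (ids, size classes, pooling), while the actual cudaHostAlloc/cudaFreeHost calls stay with the caller. The module alias, budget, pointer value, and task ID below are placeholders, not values from the library.

```python
import _pygpukit_rust as rust

# 256 MiB pinned-host budget managed by the Rust side.
mgr = rust.PinnedMemoryManager.with_max_size(256 * 1024**2)

block_id, size_class, reused = mgr.allocate(1 * 1024**2)   # ask for 1 MiB
if not reused:
    # In real code this pointer comes from cudaHostAlloc(size_class);
    # 0xDEAD0000 is just a stand-in for the sketch.
    host_ptr = 0xDEAD0000
    mgr.register(block_id, host_ptr, size_class)

mgr.associate_task(block_id, "decode_batch_0")

# ... issue H2D/D2H transfers against the pinned buffer ...

should_free, host_ptr = mgr.free(block_id)
if should_free:
    # Pooling disabled or pool full: actually release the page-locked memory,
    # i.e. call cudaFreeHost(host_ptr) here.
    pass

stats = mgr.stats()
print(f"pinned in use: {stats.current_used} bytes, pool hit rate: {stats.hit_rate():.0%}")
```

The split between `allocate`/`register` and `free`/`clear_pool` keeps the Rust core free of CUDA calls: it decides *whether* a block should be host-allocated or handed back, and the Python (or C++) caller performs the actual driver operations.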