Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
361 changes: 361 additions & 0 deletions benchmark_knn_direct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,361 @@
#!/usr/bin/env python3

import torch
import numpy as np
import time
import sys
import os

from prior_depth_anything.depth_completion import DepthCompletion


def validate_knn_distances(sparse_masks, complete_masks, knn_indices, reported_distances, method_name):
    """
    Cross-check reported KNN distances against a brute-force recomputation.

    Re-derives sparse/complete pixel coordinates exactly the way the KNN
    implementations do, gathers each complete point's K neighbours through
    ``knn_indices``, recomputes the Euclidean distances and compares them to
    ``reported_distances``.

    Args:
        sparse_masks: Boolean mask indicating sparse points
        complete_masks: Boolean mask indicating points to complete
        knn_indices: The KNN indices returned by the algorithm [num_complete, K]
        reported_distances: The distances returned by the algorithm [num_complete, K]
        method_name: Name for logging

    Returns:
        bool: True if distances are valid, False otherwise
    """
    device = sparse_masks.device

    # Mirror the KNN code's coordinate extraction: nonzero() yields
    # (batch, row, col) and the [0, 2, 1] permutation swaps to (batch, col, row).
    sparse_pts = torch.nonzero(sparse_masks, as_tuple=False)[..., [0, 2, 1]].float()
    complete_pts = torch.nonzero(complete_masks, as_tuple=False)[..., [0, 2, 1]].float()

    anchors = sparse_pts[:, -2:].contiguous()     # sparse coordinates [num_sparse, 2]
    queries = complete_pts[:, -2:].contiguous()   # complete coordinates [num_complete, 2]

    # Keep everything on one device before indexing/comparison.
    knn_indices = knn_indices.to(device)
    reported_distances = reported_distances.to(device)

    # Gather neighbour coordinates per query point, then recompute Euclidean
    # distances via broadcasting: [num_complete, 1, 2] vs [num_complete, K, 2].
    neighbour_coords = anchors[knn_indices]
    recomputed = torch.norm(queries.unsqueeze(1) - neighbour_coords, dim=2)

    diff = torch.abs(recomputed - reported_distances)
    max_diff = diff.max().item()
    mean_diff = diff.mean().item()

    tolerance = 1e-6
    distances_valid = torch.allclose(recomputed, reported_distances, atol=tolerance, rtol=tolerance)

    print(f"{method_name}: Max diff = {max_diff:.8f}, Mean diff = {mean_diff:.8f}, Valid = {distances_valid}")

    return distances_valid


def create_synthetic_sparse_data(height=480, width=640, sparsity_ratio=0.99, device='cuda', batch_size=1):
    """
    Create synthetic sparse disparity data for testing KNN algorithms.

    Returns a dict with 'sparse_disparities', 'pred_disparities',
    'sparse_masks' and 'complete_masks', each shaped [batch_size, height, width].
    """
    # Smooth synthetic disparity field (inverse depth) over a unit grid.
    rows, cols = torch.meshgrid(
        torch.linspace(0, 1, height, device=device),
        torch.linspace(0, 1, width, device=device),
        indexing='ij'
    )
    pattern = 0.3 * cols + 0.7 * rows + 0.1 * torch.sin(cols * 8) * torch.cos(rows * 8)
    dense_disparities = (0.1 + 0.9 * pattern).unsqueeze(0).repeat(batch_size, 1, 1)

    # Noisy copy standing in for a network prediction.
    pred_disparities = dense_disparities + 0.02 * torch.randn_like(dense_disparities)

    total_pixels = height * width
    num_sparse_points = int(total_pixels * (1 - sparsity_ratio))

    sparse_disparities = torch.zeros_like(dense_disparities)
    sparse_masks = torch.zeros_like(dense_disparities, dtype=torch.bool)
    complete_masks = torch.ones_like(dense_disparities, dtype=torch.bool)

    # Independently subsample pixel positions per batch element.
    for b in range(batch_size):
        chosen = torch.randperm(total_pixels, device=device)[:num_sparse_points]
        rr = chosen // width
        cc = chosen % width

        sparse_disparities[b, rr, cc] = dense_disparities[b, rr, cc]
        sparse_masks[b, rr, cc] = True
        complete_masks[b, rr, cc] = False  # sparse pixels are never re-completed

    print(f"Created synthetic data:")
    print(f" Shape: {sparse_disparities.shape}")
    print(f" Sparse points: {sparse_masks.sum().item()}")
    print(f" Complete points: {complete_masks.sum().item()}")
    print(f" Sparsity: {(sparse_masks == 0).float().mean().item()*100:.1f}% zeros")

    return {
        'sparse_disparities': sparse_disparities,
        'pred_disparities': pred_disparities,
        'sparse_masks': sparse_masks,
        'complete_masks': complete_masks
    }


def create_dummy_depth_completion():
    """Build a bare DepthCompletion instance exposing only the KNN methods."""

    class _KnnOnlyArgs:
        # Minimal stand-in for the real CLI argument namespace.
        K = 5
        extra_condition = 'error'
        normalize_confidence = False
        frozen_model_size = 'vitb'
        double_global = False

    args = _KnnOnlyArgs()
    target_device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Bypass __init__ so no depth network is loaded; KNN testing doesn't need it.
    completion = DepthCompletion.__new__(DepthCompletion)
    completion.args = args
    completion.K = args.K
    completion.device = torch.device(target_device)
    completion.depth_model = None  # unused by the KNN code paths

    return completion


def _cuda_sync():
    """Block until pending CUDA work finishes so wall-clock timings are accurate."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def _benchmark_mode(depth_completion, data, mode, K, n_runs):
    """Warm up and time one KNN implementation over ``n_runs`` runs.

    Args:
        depth_completion: Object exposing ``knn_aligns``.
        data: Dict from create_synthetic_sparse_data.
        mode: 'KD-tree' or 'torch_cluster' (selects the kd_tree flag).
        K: Number of nearest neighbors.
        n_runs: Number of timed runs.

    Returns:
        (mode_times, valid_runs, first_output): per-run times, count of
        successful runs, and CPU copies of the first run's outputs
        (None if no run succeeded).
    """
    use_kdtree = (mode == 'KD-tree')
    mode_times = []
    first_output = None

    print(f"\n 🔄 Testing {mode} mode...")

    # Warm up once so lazy initialization doesn't pollute the timings.
    try:
        _ = depth_completion.knn_aligns(
            sparse_disparities=data['sparse_disparities'],
            pred_disparities=data['pred_disparities'],
            sparse_masks=data['sparse_masks'],
            complete_masks=data['complete_masks'],
            K=K,
            kd_tree=use_kdtree,
            return_indices=True
        )
        print(f" ✅ Warm-up successful")
    except Exception as e:
        print(f" ❌ Warm-up failed: {e}")
        return mode_times, 0, None

    valid_runs = 0
    for run in range(n_runs):
        _cuda_sync()
        start_time = time.time()

        try:
            dists, k_sparse_targets, k_pred_targets, knn_indices = depth_completion.knn_aligns(
                sparse_disparities=data['sparse_disparities'],
                pred_disparities=data['pred_disparities'],
                sparse_masks=data['sparse_masks'],
                complete_masks=data['complete_masks'],
                K=K,
                kd_tree=use_kdtree,
                return_indices=True
            )

            _cuda_sync()
            run_time = time.time() - start_time
            mode_times.append(run_time)
            valid_runs += 1

            # Keep the first run's outputs (on CPU) for cross-method comparison.
            if run == 0:
                first_output = {
                    'dists': dists.cpu(),
                    'k_sparse_targets': k_sparse_targets.cpu(),
                    'k_pred_targets': k_pred_targets.cpu(),
                    'knn_indices': knn_indices.cpu()
                }

            print(f" Run {run+1}: {run_time:.6f}s")

        except Exception as e:
            print(f" ❌ Run {run+1} failed: {e}")
            import traceback
            traceback.print_exc()
            break

    return mode_times, valid_runs, first_output


def _outputs_consistent(mode_outputs, sparse_masks, complete_masks, tolerance=1e-6):
    """Cross-validate the two KNN implementations against each other and against
    a brute-force distance recomputation.

    Returns:
        bool: True when all checks pass within ``tolerance``.
    """
    print(f"\n 🔍 Comparing outputs...")
    kdtree_out = mode_outputs['KD-tree']
    cluster_out = mode_outputs['torch_cluster']

    # Distances are order-insensitive per query, so they should be identical.
    dists_diff = torch.abs(kdtree_out['dists'] - cluster_out['dists'])
    print(f" Distances - Max diff: {dists_diff.max().item():.8f}, Mean diff: {dists_diff.mean().item():.8f}")

    dists_match = torch.allclose(kdtree_out['dists'], cluster_out['dists'], atol=tolerance, rtol=tolerance)
    dists_correct_kdtree = validate_knn_distances(
        sparse_masks, complete_masks,
        kdtree_out['knn_indices'], kdtree_out['dists'], 'KD-tree'
    )
    dists_correct_cluster = validate_knn_distances(
        sparse_masks, complete_masks,
        cluster_out['knn_indices'], cluster_out['dists'], 'torch_cluster'
    )

    if dists_match and dists_correct_kdtree and dists_correct_cluster:
        print(f" ✅ All outputs match within tolerance ({tolerance})")
        return True

    if not dists_match:
        print(f" ❌ Distances do not match between methods!")
    if not dists_correct_kdtree:
        print(f" ❌ KD-tree distances do not match manually computed distances. Likely indexing errors.")
    if not dists_correct_cluster:
        print(f" ❌ torch_cluster distances do not match manually computed distances. Likely indexing errors.")
    return False


def _print_summary(results, sparsity_ratios):
    """Print per-sparsity average timings and which method was faster."""
    print(f"\n{'='*80}")
    print("PERFORMANCE SUMMARY")
    print('='*80)

    for sparsity in sparsity_ratios:
        if sparsity not in results:
            continue

        print(f"\nSparsity {sparsity*100:.1f}% zeros:")
        sparsity_results = results[sparsity]

        if 'KD-tree' in sparsity_results and 'torch_cluster' in sparsity_results:
            kdtree_time = sparsity_results['KD-tree']['avg_time']
            torch_cluster_time = sparsity_results['torch_cluster']['avg_time']

            print(f" KD-tree: {kdtree_time:.6f} ± {sparsity_results['KD-tree']['std_time']:.6f}s")
            print(f" torch_cluster: {torch_cluster_time:.6f} ± {sparsity_results['torch_cluster']['std_time']:.6f}s")

            if kdtree_time < torch_cluster_time:
                print(f" 🏆 KD-tree is {torch_cluster_time / kdtree_time:.2f}x faster")
            else:
                print(f" 🏆 torch_cluster is {kdtree_time / torch_cluster_time:.2f}x faster")


def benchmark_knn_direct(sparsity_ratios=None, n_runs=5, K=5, height=480, width=640):
    """
    Directly benchmark the KNN alignment functions.

    Args:
        sparsity_ratios: List of sparsity ratios to test
            (defaults to [0.90, 0.95, 0.99, 0.995])
        n_runs: Number of benchmark runs per test
        K: Number of nearest neighbors
        height: Image height for testing
        width: Image width for testing

    Returns:
        dict: sparsity ratio -> {mode -> timing stats}.
    """
    # Fix: avoid a mutable default argument; None stands in for the old default.
    if sparsity_ratios is None:
        sparsity_ratios = [0.90, 0.95, 0.99, 0.995]

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")
    print(f"K nearest neighbors: {K}")
    print(f"Resolution: {height}x{width}")

    # Dummy instance: gives access to knn_aligns without loading a depth model.
    depth_completion = create_dummy_depth_completion()

    results = {}

    print(f"\n{'='*80}")
    print("BENCHMARKING KNN ALIGNMENT METHODS DIRECTLY")
    print('='*80)

    for sparsity in sparsity_ratios:
        print(f"\n📊 Testing sparsity: {sparsity*100:.1f}% zeros")
        print("-" * 50)

        data = create_synthetic_sparse_data(
            height=height, width=width,
            sparsity_ratio=sparsity,
            device=device
        )

        sparsity_results = {}
        mode_outputs = {}

        for mode in ('KD-tree', 'torch_cluster'):
            mode_times, valid_runs, first_output = _benchmark_mode(
                depth_completion, data, mode, K, n_runs
            )
            if first_output is not None:
                mode_outputs[mode] = first_output

            if mode_times and valid_runs > 0:
                avg_time = np.mean(mode_times)
                std_time = np.std(mode_times)
                sparsity_results[mode] = {
                    'times': mode_times,
                    'avg_time': avg_time,
                    'std_time': std_time,
                    'valid_runs': valid_runs
                }
                print(f" 📈 Average ({valid_runs} runs): {avg_time:.6f} ± {std_time:.6f}s")

        # Only compare when both implementations produced output.
        if len(mode_outputs) == 2:
            if not _outputs_consistent(mode_outputs, data['sparse_masks'], data['complete_masks']):
                # Fix: the original bare exit() exited with status 0 on failure;
                # a validation failure should report a nonzero exit code.
                sys.exit(1)

        results[sparsity] = sparsity_results

    _print_summary(results, sparsity_ratios)

    return results


if __name__ == "__main__":
    import argparse

    # CLI wrapper around benchmark_knn_direct.
    cli = argparse.ArgumentParser(description='Benchmark KNN implementations directly')
    cli.add_argument('--sparsity', type=float, nargs='+', default=[0.60, 0.90],
                     help='Sparsity ratios to test')
    cli.add_argument('--runs', type=int, default=5,
                     help='Number of runs per test')
    cli.add_argument('--K', type=int, default=5,
                     help='Number of nearest neighbors')
    cli.add_argument('--height', type=int, default=1280,
                     help='Image height for testing')
    cli.add_argument('--width', type=int, default=1920,
                     help='Image width for testing')
    opts = cli.parse_args()

    print("🚀 Direct KNN Benchmark")
    print("=" * 80)

    try:
        results = benchmark_knn_direct(opts.sparsity, opts.runs, opts.K, opts.height, opts.width)

        print(f"\n{'='*80}")
        print("✅ BENCHMARK COMPLETE")
        print('=' * 80)
    except Exception as e:
        # Surface the failure but keep the traceback for debugging.
        print(f"❌ Benchmark failed: {e}")
        import traceback
        traceback.print_exc()
Loading