From b7bf2450b0df92f7b0ff0eb83700cbff98ad9a73 Mon Sep 17 00:00:00 2001 From: finn Date: Thu, 11 Sep 2025 16:55:36 +0200 Subject: [PATCH] Added KDTree support for KNN. Added tests/time benchmarks. --- benchmark_knn_direct.py | 361 +++++++++++++++++++++++ prior_depth_anything/__init__.py | 15 +- prior_depth_anything/depth_completion.py | 160 +++++++++- 3 files changed, 520 insertions(+), 16 deletions(-) create mode 100644 benchmark_knn_direct.py diff --git a/benchmark_knn_direct.py b/benchmark_knn_direct.py new file mode 100644 index 0000000..b7fe8da --- /dev/null +++ b/benchmark_knn_direct.py @@ -0,0 +1,361 @@ +#!/usr/bin/env python3 + +import torch +import numpy as np +import time +import sys +import os + +from prior_depth_anything.depth_completion import DepthCompletion + + +def validate_knn_distances(sparse_masks, complete_masks, knn_indices, reported_distances, method_name): + """ + Validate that reported distances match manually computed distances from knn_indices. + + Args: + sparse_masks: Boolean mask indicating sparse points + complete_masks: Boolean mask indicating points to complete + knn_indices: The KNN indices returned by the algorithm [num_complete, K] + reported_distances: The distances returned by the algorithm [num_complete, K] + method_name: Name for logging + + Returns: + bool: True if distances are valid, False otherwise + """ + # Get the device from the input tensors + device = sparse_masks.device + + # Extract coordinates the same way the KNN algorithms do + batch_sparse = torch.nonzero(sparse_masks, as_tuple=False)[..., [0, 2, 1]].float() + batch_complete = torch.nonzero(complete_masks, as_tuple=False)[..., [0, 2, 1]].float() + + x = batch_sparse[:, -2:].contiguous() # sparse coordinates [num_sparse, 2] + y = batch_complete[:, -2:].contiguous() # complete coordinates [num_complete, 2] + + # Ensure all tensors are on the same device + knn_indices = knn_indices.to(device) + reported_distances = reported_distances.to(device) + + # Vectorized distance computation + num_complete, K = knn_indices.shape + + # Get sparse points for each complete point using advanced indexing + # knn_indices: [num_complete, K] + # x[knn_indices]: [num_complete, K, 2] - coordinates of nearest neighbors + sparse_neighbors = x[knn_indices] # Shape: [num_complete, K, 2] + + # Expand complete points to match neighbor dimensions + y_expanded = y.unsqueeze(1).expand(-1, K, -1) # Shape: [num_complete, K, 2] + + # Compute distances vectorized: sqrt(sum((y - x)^2)) + manually_computed_distances = torch.norm(y_expanded - sparse_neighbors, dim=2) + + # Compare manually computed vs reported distances + distance_diff = torch.abs(manually_computed_distances - reported_distances) + max_diff = distance_diff.max().item() + mean_diff = distance_diff.mean().item() + + tolerance = 1e-6 + distances_valid = torch.allclose(manually_computed_distances, reported_distances, atol=tolerance, rtol=tolerance) + + print(f"{method_name}: Max diff = {max_diff:.8f}, Mean diff = {mean_diff:.8f}, Valid = {distances_valid}") + + return distances_valid + + +def create_synthetic_sparse_data(height=480, width=640, sparsity_ratio=0.99, device='cuda', batch_size=1): + """ + Create synthetic sparse disparity data for testing KNN algorithms. + """ + # Create synthetic dense disparities (simulating a complete depth map converted to disparity) + y_coords, x_coords = torch.meshgrid( + torch.linspace(0, 1, height, device=device), + torch.linspace(0, 1, width, device=device), + indexing='ij' + ) + # Create a realistic disparity pattern (inverse depth) + dense_disparities = 0.1 + 0.9 * (0.3 * x_coords + 0.7 * y_coords + + 0.1 * torch.sin(x_coords * 8) * torch.cos(y_coords * 8)) + dense_disparities = dense_disparities.unsqueeze(0).repeat(batch_size, 1, 1) + + # Create predicted disparities (slightly different from ground truth) + pred_disparities = dense_disparities + 0.02 * torch.randn_like(dense_disparities) + + # Create sparse version by randomly masking pixels + total_pixels = height * width + num_sparse_points = int(total_pixels * (1 - sparsity_ratio)) + + sparse_disparities = torch.zeros_like(dense_disparities) + sparse_masks = torch.zeros_like(dense_disparities, dtype=torch.bool) + complete_masks = torch.ones_like(dense_disparities, dtype=torch.bool) + + for b in range(batch_size): + # Randomly select pixels to keep as sparse points + flat_indices = torch.randperm(total_pixels, device=device)[:num_sparse_points] + row_indices = flat_indices // width + col_indices = flat_indices % width + + sparse_disparities[b, row_indices, col_indices] = dense_disparities[b, row_indices, col_indices] + sparse_masks[b, row_indices, col_indices] = True + complete_masks[b, row_indices, col_indices] = False # Don't complete sparse points + + print(f"Created synthetic data:") + print(f" Shape: {sparse_disparities.shape}") + print(f" Sparse points: {sparse_masks.sum().item()}") + print(f" Complete points: {complete_masks.sum().item()}") + print(f" Sparsity: {(sparse_masks == 0).float().mean().item()*100:.1f}% zeros") + + return { + 'sparse_disparities': sparse_disparities, + 'pred_disparities': pred_disparities, + 'sparse_masks': sparse_masks, + 'complete_masks': complete_masks + } + + +def create_dummy_depth_completion(): + """Create a DepthCompletion instance just to access the KNN methods.""" + class DummyArgs: + def __init__(self): + self.K = 5 + self.extra_condition = 'error' + self.normalize_confidence = False + self.frozen_model_size = 'vitb' + self.double_global = False + + args = DummyArgs() + device = "cuda:0" if torch.cuda.is_available() else "cpu" + + # We'll create the instance without initializing the depth model + depth_completion = DepthCompletion.__new__(DepthCompletion) + depth_completion.args = args + depth_completion.K = args.K + depth_completion.device = torch.device(device) + depth_completion.depth_model = None # We don't need this for KNN testing + + return depth_completion + + +def benchmark_knn_direct(sparsity_ratios=[0.90, 0.95, 0.99, 0.995], n_runs=5, K=5, height=480, width=640): + """ + Directly benchmark the KNN alignment functions. + + Args: + sparsity_ratios: List of sparsity ratios to test + n_runs: Number of benchmark runs per test + K: Number of nearest neighbors + height: Image height for testing + width: Image width for testing + """ + device = "cuda:0" if torch.cuda.is_available() else "cpu" + print(f"Using device: {device}") + print(f"K nearest neighbors: {K}") + print(f"Resolution: {height}x{width}") + + # Create depth completion instance for method access + depth_completion = create_dummy_depth_completion() + + results = {} + + print(f"\n{'='*80}") + print("BENCHMARKING KNN ALIGNMENT METHODS DIRECTLY") + print('='*80) + + for sparsity in sparsity_ratios: + print(f"\nšŸ“Š Testing sparsity: {sparsity*100:.1f}% zeros") + print("-" * 50) + + data = create_synthetic_sparse_data( + height=height, width=width, + sparsity_ratio=sparsity, + device=device + ) + + sparse_disparities = data['sparse_disparities'] + pred_disparities = data['pred_disparities'] + sparse_masks = data['sparse_masks'] + complete_masks = data['complete_masks'] + + # Test both KNN implementations + modes = ['KD-tree', 'torch_cluster'] + sparsity_results = {} + mode_outputs = {} + + for mode in modes: + use_kdtree = (mode == 'KD-tree') + mode_times = [] + + print(f"\n šŸ”„ Testing {mode} mode...") + + # Warm up + try: + _ = depth_completion.knn_aligns( + sparse_disparities=sparse_disparities, + pred_disparities=pred_disparities, + sparse_masks=sparse_masks, + complete_masks=complete_masks, + K=K, + kd_tree=use_kdtree, + return_indices=True + ) + print(f" āœ… Warm-up successful") + except Exception as e: + print(f" āŒ Warm-up failed: {e}") + continue + + # Benchmark runs + valid_runs = 0 + for run in range(n_runs): + torch.cuda.synchronize() if torch.cuda.is_available() else None + start_time = time.time() + + try: + dists, k_sparse_targets, k_pred_targets, knn_indices = depth_completion.knn_aligns( + sparse_disparities=sparse_disparities, + pred_disparities=pred_disparities, + sparse_masks=sparse_masks, + complete_masks=complete_masks, + K=K, + kd_tree=use_kdtree, + return_indices=True + ) + + torch.cuda.synchronize() if torch.cuda.is_available() else None + run_time = time.time() - start_time + mode_times.append(run_time) + valid_runs += 1 + + # Store output for the first run for comparison + if run == 0: + mode_outputs[mode] = { + 'dists': dists.cpu(), + 'k_sparse_targets': k_sparse_targets.cpu(), + 'k_pred_targets': k_pred_targets.cpu(), + 'knn_indices': knn_indices.cpu() + } + + print(f" Run {run+1}: {run_time:.6f}s") + + except Exception as e: + print(f" āŒ Run {run+1} failed: {e}") + import traceback + traceback.print_exc() + break + + if mode_times and valid_runs > 0: + avg_time = np.mean(mode_times) + std_time = np.std(mode_times) + sparsity_results[mode] = { + 'times': mode_times, + 'avg_time': avg_time, + 'std_time': std_time, + 'valid_runs': valid_runs + } + print(f" šŸ“ˆ Average ({valid_runs} runs): {avg_time:.6f} ± {std_time:.6f}s") + + # Compare outputs between methods + if len(mode_outputs) == 2: + print(f"\n šŸ” Comparing outputs...") + kdtree_out = mode_outputs['KD-tree'] + cluster_out = mode_outputs['torch_cluster'] + + # Compare distances (should be identical) + dists_diff = torch.abs(kdtree_out['dists'] - cluster_out['dists']) + print(f" Distances - Max diff: {dists_diff.max().item():.8f}, Mean diff: {dists_diff.mean().item():.8f}") + + # For sparse/pred targets, we need to sort each K-neighbor group to handle tie-breaking differences + # Each complete point has K neighbors, so reshape to [num_complete, K, ...] + def sort_knn_results(targets, indices): + """Sort K-neighbors for each complete point by their indices to ensure consistent ordering.""" + # targets and indices are [num_complete, K, ...] + sorted_indices = torch.argsort(indices, dim=1) # Sort by indices for consistency + sorted_targets = torch.gather(targets, 1, sorted_indices) + return sorted_targets + + # Overall match check + tolerance = 1e-6 + dists_match = torch.allclose(kdtree_out['dists'], cluster_out['dists'], atol=tolerance, rtol=tolerance) + dists_correct_kdtree = validate_knn_distances( + sparse_masks, complete_masks, + kdtree_out['knn_indices'], kdtree_out['dists'], 'KD-tree' + ) + dists_correct_cluster = validate_knn_distances( + sparse_masks, complete_masks, + cluster_out['knn_indices'], cluster_out['dists'], 'torch_cluster' + ) + + if dists_match and dists_correct_kdtree and dists_correct_cluster: + print(f" āœ… All outputs match within tolerance ({tolerance})") + else: + if not dists_match: + print(f" āŒ Distances do not match between methods!") + if not dists_correct_kdtree: + print(f" āŒ KD-tree distances do not match manually computed distances. Likely indexing errors.") + if not dists_correct_cluster: + print(f" āŒ torch_cluster distances do not match manually computed distances. Likely indexing errors.") + exit() + + results[sparsity] = sparsity_results + + # Print summary + print(f"\n{'='*80}") + print("PERFORMANCE SUMMARY") + print('='*80) + + for sparsity in sparsity_ratios: + if sparsity not in results: + continue + + print(f"\nSparsity {sparsity*100:.1f}% zeros:") + sparsity_results = results[sparsity] + + if 'KD-tree' in sparsity_results and 'torch_cluster' in sparsity_results: + kdtree_time = sparsity_results['KD-tree']['avg_time'] + torch_cluster_time = sparsity_results['torch_cluster']['avg_time'] + + print(f" KD-tree: {kdtree_time:.6f} ± {sparsity_results['KD-tree']['std_time']:.6f}s") + print(f" torch_cluster: {torch_cluster_time:.6f} ± {sparsity_results['torch_cluster']['std_time']:.6f}s") + + if kdtree_time < torch_cluster_time: + speedup = torch_cluster_time / kdtree_time + print(f" šŸ† KD-tree is {speedup:.2f}x faster") + else: + speedup = kdtree_time / torch_cluster_time + print(f" šŸ† torch_cluster is {speedup:.2f}x faster") + + return results + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description='Benchmark KNN implementations directly') + parser.add_argument('--sparsity', type=float, nargs='+', + default=[0.60, 0.90], + help='Sparsity ratios to test') + parser.add_argument('--runs', type=int, default=5, + help='Number of runs per test') + parser.add_argument('--K', type=int, default=5, + help='Number of nearest neighbors') + parser.add_argument('--height', type=int, default=1280, + help='Image height for testing') + parser.add_argument('--width', type=int, default=1920, + help='Image width for testing') + + args = parser.parse_args() + + print("šŸš€ Direct KNN Benchmark") + print("="*80) + + try: + results = benchmark_knn_direct(args.sparsity, args.runs, args.K, args.height, args.width) + + print(f"\n{'='*80}") + print("āœ… BENCHMARK COMPLETE") + print('='*80) + + except Exception as e: + print(f"āŒ Benchmark failed: {e}") + import traceback + traceback.print_exc() \ No newline at end of file diff --git a/prior_depth_anything/__init__.py b/prior_depth_anything/__init__.py index 5a5547b..09ff3dc 100644 --- a/prior_depth_anything/__init__.py +++ b/prior_depth_anything/__init__.py @@ -107,7 +107,8 @@ def forward(self, cover_masks: torch.Tensor = None, prior_depths: torch.Tensor = None, geometric_depths: torch.Tensor = None, - pattern: Optional[str] = None + pattern: Optional[str] = None, + kd_tree: bool = True ): """ To facilitate further research, we batchify the forward process. """ ##### Coarse stage. ##### @@ -118,7 +119,8 @@ def forward(self, cover_masks=cover_masks, prior_depths=prior_depths, pattern=pattern, - geometric_depths=geometric_depths + geometric_depths=geometric_depths, + kd_tree=kd_tree ) # knn-aligned depths @@ -249,7 +251,8 @@ def infer_one_sample(self, double_global: bool = False, prior_cover: bool = False, visualize: bool = False, - down_fill_mode: str = 'linear' + down_fill_mode: str = 'linear', + kd_tree: bool = True ) -> torch.Tensor: """ Perform inference. Return the refined/completed depth. @@ -266,7 +269,8 @@ def infer_one_sample(self, pattern: The mode of prior-based additional sampling. It could be None. double_global: Whether to condition with two estimated depths or estimated + knn-map. prior_cover: Whether to keep all prior areas in knn-map, it functions when 'pattern' is not None. - visualize: Save results. + visualize: Save results. + kd_tree: Whether to use KD-tree or torch_cluster for KNN. Defaults to True. Example1: @@ -318,7 +322,8 @@ def infer_one_sample(self, sparse_masks=sparse_mask, cover_masks=cover_mask, pattern=pattern, - geometric_depths=geometric_depth + geometric_depths=geometric_depth, + kd_tree=kd_tree ) # (B, 1, H, W) ### Visualize the results. diff --git a/prior_depth_anything/depth_completion.py b/prior_depth_anything/depth_completion.py index 0c5b240..5da91a8 100644 --- a/prior_depth_anything/depth_completion.py +++ b/prior_depth_anything/depth_completion.py @@ -4,6 +4,7 @@ import warnings import time from typing import Dict, Tuple, Optional +from torch_kdtree import build_kd_tree from .utils import ( depth2disparity, @@ -116,7 +117,8 @@ def forward(self, prior_depths: Optional[torch.Tensor] = None, geometric_depths: Optional[torch.Tensor] = None, pattern: Optional[str] = None, - ret: str = 'all' # ret = 'knn' or 'global' + ret: str = 'all', # ret = 'knn' or 'global' + kd_tree: bool = True ) -> Dict[str, torch.Tensor]: """ Processe input images and sparse depth information to produce completed depth maps. @@ -129,6 +131,8 @@ def forward(self, cover_masks (torch.Tensor, optional): Indicating areas to be covered by prior depth. prior_depths (torch.Tensor, optional): Prior depth information for covering large areas. pattern (optional): Pattern for sampling sparse depth points. + ret (str): Return type - 'knn', 'global', or 'all'. Defaults to 'all'. + kd_tree (bool): Whether to use KD-tree or torch_cluster for KNN. Defaults to True. Returns: Dict[str, torch.Tensor]: Containing the processed data, including: @@ -176,6 +180,7 @@ def forward(self, sparse_masks=sparse_masks, complete_masks=complete_masks, K=self.K, + kd_tree=kd_tree ) """ @@ -320,15 +325,16 @@ def perform_weighted(self, sparse_weighted = W @ sparse_ori.unsqueeze(-1) return sparse_weighted, pred_weighted - def knn_aligns(self, + def _knn_aligns_torch_cluster(self, sparse_disparities: torch.Tensor, pred_disparities: torch.Tensor, sparse_masks: torch.Tensor, complete_masks: torch.Tensor, - K: int + K: int, + return_indices: bool = False ) -> Tuple[torch.Tensor, ...]: """ - Perform K-Nearest Neighbors (KNN) alignment on sparse and predicted disparities. + Perform K-Nearest Neighbors (KNN) alignment using torch_cluster.knn. Args: sparse_disparities (torch.Tensor): Disparities for sparse map points. @@ -343,7 +349,6 @@ def knn_aligns(self, - k_sparse_targets: Disparities of the K nearest neighbors from the sparse data. - k_pred_targets: Disparities of the K nearest neighbors from the predicted data. """ - # Coordinates are processed to ensure compatibility with the KNN function. batch_sparse = torch.nonzero(sparse_masks, as_tuple=False)[..., [0, 2, 1]].float() # [N, 3] (b, x, y) batch_complete = torch.nonzero(complete_masks, as_tuple=False)[..., [0, 2, 1]].float() # [M, 3] (b, x, y) @@ -362,20 +367,151 @@ def knn_aligns(self, knn_coords = x[knn_indices] expanded_complete_points = y.unsqueeze(dim=1).repeat(1, K, 1) dists = torch.norm(expanded_complete_points - knn_coords, dim=2) + # print(knn_indices[-1]) + if return_indices: + return dists, k_sparse_targets, k_pred_targets, knn_indices + else: + return dists, k_sparse_targets, k_pred_targets + + + def _knn_aligns_kdtree(self, + sparse_disparities: torch.Tensor, + pred_disparities: torch.Tensor, + sparse_masks: torch.Tensor, + complete_masks: torch.Tensor, + K: int, + return_indices: bool = False + ) -> Tuple[torch.Tensor, ...]: + """ + Perform K-Nearest Neighbors (KNN) alignment using KD-tree (torch_kdtree). + """ + # Use the same coordinate extraction as torch_cluster + batch_sparse = torch.nonzero(sparse_masks, as_tuple=False)[..., [0, 2, 1]].float() + batch_complete = torch.nonzero(complete_masks, as_tuple=False)[..., [0, 2, 1]].float() + + batch_x, batch_y = batch_sparse[:, 0].contiguous(), batch_complete[:, 0].contiguous() + x, y = batch_sparse[:, -2:].contiguous(), batch_complete[:, -2:].contiguous() + + # Initialize outputs with the correct shape + num_complete = y.shape[0] + knn_indices = torch.zeros((num_complete, K), dtype=torch.long, device=x.device) + dists = torch.zeros((num_complete, K), dtype=torch.float32, device=x.device) + + # Get unique batch indices to process each batch separately + unique_batches = torch.unique(torch.cat([batch_x, batch_y])) + + # Track the current position in the output arrays + complete_offset = 0 + + for batch_idx in unique_batches: + # Find points belonging to this batch + sparse_in_batch = batch_x == batch_idx + complete_in_batch = batch_y == batch_idx + + if not sparse_in_batch.any() or not complete_in_batch.any(): + continue + + # Extract coordinates for this batch + x_batch = x[sparse_in_batch] + y_batch = y[complete_in_batch] + + # Get the number of points in this batch + num_complete_batch = y_batch.shape[0] + + if x_batch.shape[0] == 0 or num_complete_batch == 0: + complete_offset += num_complete_batch + continue + + # Build KD-tree for sparse points + try: + kdtree = build_kd_tree(x_batch) + batch_dists, batch_indices = kdtree.query(y_batch, nr_nns_searches=K) + + # Convert squared distances to actual distances + batch_dists = torch.sqrt(torch.clamp(batch_dists, min=0)) + + # Map local indices to global indices + # Get the global indices of sparse points in this batch + global_sparse_indices = torch.where(sparse_in_batch)[0] + + # Convert local KD-tree indices to global indices + # batch_indices: [num_complete_batch, K] + # global_sparse_indices: [num_sparse_batch] + global_batch_indices = global_sparse_indices[batch_indices] + + # Find where to place these results in the global output + complete_positions = torch.where(complete_in_batch)[0] + + # Store results + knn_indices[complete_positions] = global_batch_indices + dists[complete_positions] = batch_dists + + except Exception as e: + print(f"Error processing batch {batch_idx}: {e}") + # Fill with fallback values if needed + complete_positions = torch.where(complete_in_batch)[0] + if len(complete_positions) > 0: + # Use first sparse point as fallback + first_sparse_idx = torch.where(sparse_in_batch)[0][0] if sparse_in_batch.any() else 0 + knn_indices[complete_positions] = first_sparse_idx + dists[complete_positions] = 1.0 # Large distance as fallback + + # Extract the disparity values using the computed indices + k_sparse_targets = sparse_disparities[sparse_masks][knn_indices] + k_pred_targets = pred_disparities[sparse_masks][knn_indices] + + if return_indices: + return dists, k_sparse_targets, k_pred_targets, knn_indices + else: + return dists, k_sparse_targets, k_pred_targets + + + def knn_aligns(self, + sparse_disparities: torch.Tensor, + pred_disparities: torch.Tensor, + sparse_masks: torch.Tensor, + complete_masks: torch.Tensor, + K: int, + kd_tree: bool = True, + return_indices: bool = False + ) -> Tuple[torch.Tensor, ...]: + """ + Perform K-Nearest Neighbors (KNN) alignment on sparse and predicted disparities. + Dispatches to either KD-tree or torch_cluster implementation. - return dists, k_sparse_targets, k_pred_targets + Args: + sparse_disparities (torch.Tensor): Disparities for sparse map points. + pred_disparities (torch.Tensor): Predicted disparities for sparse map points. + sparse_masks (torch.Tensor): Indicating which points in the sparse map are valid. + complete_masks (torch.Tensor): Indicating which points in the map to be completed. + K (int): The number of nearest neighbors to find for each map point. + kd_tree (bool): If True, use KD-tree; if False, use torch_cluster.knn. + + Returns: + Tuple: Containing three tensors: + - dists: The Euclidean distances from each sparse point to its K nearest neighbors. + - k_sparse_targets: Disparities of the K nearest neighbors from the sparse data. + - k_pred_targets: Disparities of the K nearest neighbors from the predicted data. + """ + if kd_tree: + return self._knn_aligns_kdtree( + sparse_disparities, pred_disparities, sparse_masks, complete_masks, K, return_indices=return_indices + ) + else: + return self._knn_aligns_torch_cluster( + sparse_disparities, pred_disparities, sparse_masks, complete_masks, K, return_indices=return_indices + ) def kss_completer(self, sparse_disparities: torch.Tensor, pred_disparities: torch.Tensor, complete_masks: torch.Tensor, sparse_masks: torch.Tensor, - K: int = 5 + K: int = 5, + kd_tree: bool = True, ) -> torch.Tensor: """ - Perform K-Nearest Neighbors (KNN) interpolation to complete sparse disparities.Use a batch-oriented - implementation of KNN interpolation to complete the sparse disparities. We leverages "torch_cluster.knn" - for acceleration and GPU memory efficiency. + Perform K-Nearest Neighbors (KNN) interpolation to complete sparse disparities. Args: sparse_disparities (torch.Tensor): Disparities for sparse map. @@ -383,6 +519,7 @@ def kss_completer(self, complete_masks (torch.Tensor): Indicating which points in the complete map are valid. sparse_masks (torch.Tensor): Indicating which points in the sparse map are valid. K (int): The number of nearest neighbors to use for interpolation. Defaults to 5. + kd_tree (bool): Whether to use KD-tree or torch_cluster for KNN. Defaults to True. Returns: The completed disparities, interpolated from the nearest neighbors. @@ -394,7 +531,8 @@ def kss_completer(self, pred_disparities=pred_disparities, sparse_masks=sparse_masks, complete_masks=complete_masks, - K=K + K=K, + kd_tree=kd_tree ) scaled_preds = torch.zeros_like(sparse_disparities, device=self.device, dtype=torch.float32)