From 07992de09eb0c7cdcb5c9dffd5b33fc1d33c4fdc Mon Sep 17 00:00:00 2001 From: Chriszkxxx Date: Tue, 17 Mar 2026 17:09:15 +0800 Subject: [PATCH] fix_cut3r --- src/openworldlib/operators/cut3r_operator.py | 583 ++++++++++-------- .../pipelines/cut3r/pipeline_cut3r.py | 235 ++----- .../cut3r/cut3r_representation.py | 343 ++++++++++- test/test_cut3r.py | 13 +- 4 files changed, 739 insertions(+), 435 deletions(-) diff --git a/src/openworldlib/operators/cut3r_operator.py b/src/openworldlib/operators/cut3r_operator.py index 61f6ea21..d9a00d0d 100644 --- a/src/openworldlib/operators/cut3r_operator.py +++ b/src/openworldlib/operators/cut3r_operator.py @@ -1,246 +1,337 @@ -import os -import cv2 -import numpy as np -import torch -from typing import List, Optional, Union, Dict, Any -from pathlib import Path - -from .base_operator import BaseOperator - - -class CUT3ROperator(BaseOperator): - """Operator for CUT3R pipeline utilities.""" - - def __init__( - self, - operation_types=["visual_instruction"], - interaction_template=[ - "image_3d", - "video_3d", - "point_cloud", - "depth_map", - "camera_pose", - "move_left", - "move_right", - "move_up", - "move_down", - "zoom_in", - "zoom_out", - "rotate_left", - "rotate_right" - ] - ): - """ - Initialize CUT3R operator. - - Args: - operation_types: List of operation types - interaction_template: List of valid interaction types - - "image_3d": Process single image for 3D reconstruction - - "video_3d": Process video for 3D reconstruction - - "point_cloud": Generate point cloud output - - "depth_map": Generate depth map output - - "camera_pose": Estimate camera poses - - "move_left/right/up/down": Camera movement controls - - "zoom_in/out": Camera zoom controls - - "rotate_left/right": Camera rotation controls - """ - super(CUT3ROperator, self).__init__(operation_types=operation_types) - self.interaction_template = interaction_template - self.interaction_template_init() - - def collect_paths(self, path: Union[str, Path]) -> List[str]: - """ - Collect file paths from a file, directory, or txt list. - - Args: - path: File path, directory path, or txt file containing paths - - Returns: - List of file paths - """ - path = str(path) - if os.path.isfile(path): - if path.lower().endswith(".txt"): - with open(path, "r", encoding="utf-8") as handle: - files = [line.strip() for line in handle.readlines() if line.strip()] - else: - files = [path] - else: - files = [ - os.path.join(path, name) - for name in os.listdir(path) - if not name.startswith(".") and name.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp')) - ] - files.sort() - return files - - def process_perception( - self, - input_signal: Union[str, np.ndarray, torch.Tensor, List[str], List[np.ndarray]] - ) -> Union[np.ndarray, List[np.ndarray]]: - """ - Process visual signal (image/video) for real-time interactive updates. - This function handles loading and preprocessing of images from various input types. - - Args: - input_signal: Visual input signal - can be: - - Image file path (str) - - List of image file paths (List[str]) - - Numpy array (H, W, 3) in RGB or BGR format - - List of numpy arrays - - Torch tensor (C, H, W) or (1, C, H, W) in CHW format - - Returns: - Preprocessed RGB image array(s) (normalized to [0, 1]) with shape (H, W, 3) - or list of such arrays - - Raises: - ValueError: If image cannot be loaded or processed - """ - # Handle list inputs - if isinstance(input_signal, list): - return [self.process_perception(item) for item in input_signal] - - # Handle single input - if isinstance(input_signal, torch.Tensor): - # Assume tensor is in CHW format, convert to numpy - if input_signal.dim() == 3: - image_rgb = input_signal.permute(1, 2, 0).cpu().numpy() - else: - image_rgb = input_signal[0].permute(1, 2, 0).cpu().numpy() - if image_rgb.max() > 1.0: - image_rgb = image_rgb / 255.0 - elif isinstance(input_signal, np.ndarray): - image_rgb = input_signal / 255.0 if input_signal.max() > 1.0 else input_signal - # Convert BGR to RGB if needed (heuristic: if first channel mean > last channel mean) - if len(image_rgb.shape) == 3 and image_rgb.shape[2] == 3: - if image_rgb[..., 0].mean() > image_rgb[..., 2].mean(): - image_rgb = image_rgb[..., ::-1] - else: - # String path: support single image, directory, or txt list. - if isinstance(input_signal, (str, Path)): - input_path = str(input_signal) - if os.path.isdir(input_path) or (os.path.isfile(input_path) and input_path.lower().endswith(".txt")): - file_list = self.collect_paths(input_path) - if len(file_list) == 0: - raise ValueError(f"No valid image files found in {input_path}") - return [self.process_perception(p) for p in file_list] - - raw_image = cv2.imread(input_path) - if raw_image is None: - raise ValueError(f"Could not read image from {input_signal}") - image_rgb = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0 - else: - raise ValueError(f"Unsupported input type for process_perception: {type(input_signal)}") - - return image_rgb - - def check_interaction(self, interaction): - """ - Check if interaction is in the interaction template. - - Args: - interaction: Interaction string to check - - Returns: - True if interaction is valid - - Raises: - ValueError: If interaction is not in template - """ - if interaction not in self.interaction_template: - raise ValueError(f"Interaction '{interaction}' not in interaction_template. " - f"Available interactions: {self.interaction_template}") - return True - - def get_interaction(self, interaction): - """ - Add interaction to current_interaction list after validation. - - Args: - interaction: Interaction string to add - """ - self.check_interaction(interaction) - self.current_interaction.append(interaction) - - def process_interaction(self, num_frames: Optional[int] = None) -> Dict[str, Any]: - """ - Process current interactions and convert to features for representation/synthesis. - - Args: - num_frames: Number of frames (for video processing, optional) - - Returns: - Dictionary containing processed interaction features: - - data_type: "image" or "video" - - output_type: "point_cloud", "depth_map", "camera_pose", or "all" - - camera_control: Dict with camera movement parameters (if applicable) - """ - if len(self.current_interaction) == 0: - raise ValueError("No interaction to process. Use get_interaction() first.") - - # Get the latest interaction - latest_interaction = self.current_interaction[-1] - self.interaction_history.append(latest_interaction) - - # Process interaction based on type - result = { - "data_type": "image", - "output_type": "all", # point_cloud, depth_map, camera_pose, or all - "camera_control": None - } - - # Data type interactions - if latest_interaction == "image_3d": - result["data_type"] = "image" - result["output_type"] = "all" - elif latest_interaction == "video_3d": - result["data_type"] = "video" - result["output_type"] = "all" - elif latest_interaction == "point_cloud": - result["data_type"] = "image" - result["output_type"] = "point_cloud" - elif latest_interaction == "depth_map": - result["data_type"] = "image" - result["output_type"] = "depth_map" - elif latest_interaction == "camera_pose": - result["data_type"] = "image" - result["output_type"] = "camera_pose" - - # Camera control interactions - elif latest_interaction in ["move_left", "move_right", "move_up", "move_down"]: - direction_map = { - "move_left": {"x": -0.1, "y": 0, "z": 0}, - "move_right": {"x": 0.1, "y": 0, "z": 0}, - "move_up": {"x": 0, "y": 0.1, "z": 0}, - "move_down": {"x": 0, "y": -0.1, "z": 0}, - } - result["camera_control"] = direction_map[latest_interaction] - elif latest_interaction in ["zoom_in", "zoom_out"]: - zoom_map = { - "zoom_in": {"scale": 1.1}, - "zoom_out": {"scale": 0.9}, - } - result["camera_control"] = zoom_map[latest_interaction] - elif latest_interaction in ["rotate_left", "rotate_right"]: - rotation_map = { - "rotate_left": {"angle": -10}, - "rotate_right": {"angle": 10}, - } - result["camera_control"] = rotation_map[latest_interaction] - - # Add num_frames if provided (for video processing) - if num_frames is not None: - result["num_frames"] = num_frames - - return result - - def delete_last_interaction(self): - """Delete the last interaction from current_interaction list.""" - if len(self.current_interaction) > 0: - self.current_interaction = self.current_interaction[:-1] - else: - raise ValueError("No interaction to delete.") - - +import os +import cv2 +import numpy as np +import torch +from typing import List, Optional, Union, Dict, Any +from pathlib import Path +from PIL import Image + +from .base_operator import BaseOperator + + +class CUT3ROperator(BaseOperator): + """Operator for CUT3R pipeline utilities.""" + + def __init__( + self, + operation_types=["visual_instruction"], + interaction_template=[ + "image_3d", + "video_3d", + "point_cloud", + "depth_map", + "camera_pose", + "move_left", + "move_right", + "move_up", + "move_down", + "zoom_in", + "zoom_out", + "rotate_left", + "rotate_right" + ] + ): + """ + Initialize CUT3R operator. + + Args: + operation_types: List of operation types + interaction_template: List of valid interaction types + - "image_3d": Process single image for 3D reconstruction + - "video_3d": Process video for 3D reconstruction + - "point_cloud": Generate point cloud output + - "depth_map": Generate depth map output + - "camera_pose": Estimate camera poses + - "move_left/right/up/down": Camera movement controls + - "zoom_in/out": Camera zoom controls + - "rotate_left/right": Camera rotation controls + """ + super(CUT3ROperator, self).__init__(operation_types=operation_types) + self.interaction_template = interaction_template + self.interaction_template_init() + + def collect_paths(self, path: Union[str, Path]) -> List[str]: + """ + Collect file paths from a file, directory, or txt list. + + Args: + path: File path, directory path, or txt file containing paths + + Returns: + List of file paths + """ + path = str(path) + if os.path.isfile(path): + if path.lower().endswith(".txt"): + with open(path, "r", encoding="utf-8") as handle: + files = [line.strip() for line in handle.readlines() if line.strip()] + else: + files = [path] + else: + files = [ + os.path.join(path, name) + for name in os.listdir(path) + if not name.startswith(".") and name.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp')) + ] + files.sort() + return files + + def process_perception( + self, + input_signal: Union[str, np.ndarray, torch.Tensor, Image.Image, List[str], List[np.ndarray], List[Image.Image]] + ) -> Union[np.ndarray, List[np.ndarray]]: + """ + Process visual signal (image/video) for real-time interactive updates. + This function handles loading and preprocessing of images from various input types. + + Args: + input_signal: Visual input signal - can be: + - Image file path (str) + - List of image file paths (List[str]) + - Numpy array (H, W, 3) in RGB or BGR format + - List of numpy arrays + - Torch tensor (C, H, W) or (1, C, H, W) in CHW format + + Returns: + Preprocessed RGB image array(s) (normalized to [0, 1]) with shape (H, W, 3) + or list of such arrays + + Raises: + ValueError: If image cannot be loaded or processed + """ + # Handle list inputs (paths, numpy arrays, tensors, PIL Images) + if isinstance(input_signal, list): + return [self.process_perception(item) for item in input_signal] + + # Handle single input + if isinstance(input_signal, Image.Image): + image_rgb = np.array(input_signal) + if image_rgb.dtype != np.float32: + image_rgb = image_rgb.astype(np.float32) + if image_rgb.max() > 1.0: + image_rgb = image_rgb / 255.0 + elif isinstance(input_signal, torch.Tensor): + # Assume tensor is in CHW format, convert to numpy + if input_signal.dim() == 3: + image_rgb = input_signal.permute(1, 2, 0).cpu().numpy() + else: + image_rgb = input_signal[0].permute(1, 2, 0).cpu().numpy() + if image_rgb.max() > 1.0: + image_rgb = image_rgb / 255.0 + elif isinstance(input_signal, np.ndarray): + image_rgb = input_signal / 255.0 if input_signal.max() > 1.0 else input_signal + # Convert BGR to RGB if needed (heuristic: if first channel mean > last channel mean) + if len(image_rgb.shape) == 3 and image_rgb.shape[2] == 3: + if image_rgb[..., 0].mean() > image_rgb[..., 2].mean(): + image_rgb = image_rgb[..., ::-1] + else: + # String path: support single image, directory, or txt list. + if isinstance(input_signal, (str, Path)): + input_path = str(input_signal) + if os.path.isdir(input_path) or (os.path.isfile(input_path) and input_path.lower().endswith(".txt")): + file_list = self.collect_paths(input_path) + if len(file_list) == 0: + raise ValueError(f"No valid image files found in {input_path}") + return [self.process_perception(p) for p in file_list] + + raw_image = cv2.imread(input_path) + if raw_image is None: + raise ValueError(f"Could not read image from {input_signal}") + image_rgb = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0 + else: + raise ValueError(f"Unsupported input type for process_perception: {type(input_signal)}") + + return image_rgb + + def check_interaction(self, interaction): + """ + Check if interaction is in the interaction template. + + Args: + interaction: Interaction string to check + + Returns: + True if interaction is valid + + Raises: + ValueError: If interaction is not in template + """ + if interaction not in self.interaction_template: + raise ValueError(f"Interaction '{interaction}' not in interaction_template. " + f"Available interactions: {self.interaction_template}") + return True + + def get_interaction(self, interaction): + """ + Add interaction to current_interaction list after validation. + + Args: + interaction: Interaction string to add + """ + self.check_interaction(interaction) + self.current_interaction.append(interaction) + + def process_interaction(self, num_frames: Optional[int] = None) -> Dict[str, Any]: + """ + Process current interactions and convert to features for representation/synthesis. + + Args: + num_frames: Number of frames (for video processing, optional) + + Returns: + Dictionary containing processed interaction features: + - data_type: "image" or "video" + - output_type: "point_cloud", "depth_map", "camera_pose", or "all" + - camera_control: Dict with camera movement parameters (if applicable) + """ + if len(self.current_interaction) == 0: + raise ValueError("No interaction to process. Use get_interaction() first.") + + # Get the latest interaction + latest_interaction = self.current_interaction[-1] + self.interaction_history.append(latest_interaction) + + # Process interaction based on type + result = { + "data_type": "image", + "output_type": "all", # point_cloud, depth_map, camera_pose, or all + "camera_control": None + } + + # Data type interactions + if latest_interaction == "image_3d": + result["data_type"] = "image" + result["output_type"] = "all" + elif latest_interaction == "video_3d": + result["data_type"] = "video" + result["output_type"] = "all" + elif latest_interaction == "point_cloud": + result["data_type"] = "image" + result["output_type"] = "point_cloud" + elif latest_interaction == "depth_map": + result["data_type"] = "image" + result["output_type"] = "depth_map" + elif latest_interaction == "camera_pose": + result["data_type"] = "image" + result["output_type"] = "camera_pose" + + # Camera control interactions + elif latest_interaction in ["move_left", "move_right", "move_up", "move_down"]: + direction_map = { + "move_left": {"x": -0.1, "y": 0, "z": 0}, + "move_right": {"x": 0.1, "y": 0, "z": 0}, + "move_up": {"x": 0, "y": 0.1, "z": 0}, + "move_down": {"x": 0, "y": -0.1, "z": 0}, + } + result["camera_control"] = direction_map[latest_interaction] + elif latest_interaction in ["zoom_in", "zoom_out"]: + zoom_map = { + "zoom_in": {"scale": 1.1}, + "zoom_out": {"scale": 0.9}, + } + result["camera_control"] = zoom_map[latest_interaction] + elif latest_interaction in ["rotate_left", "rotate_right"]: + rotation_map = { + "rotate_left": {"angle": -10}, + "rotate_right": {"angle": 10}, + } + result["camera_control"] = rotation_map[latest_interaction] + + # Add num_frames if provided (for video processing) + if num_frames is not None: + result["num_frames"] = num_frames + + return result + + def delete_last_interaction(self): + """Delete the last interaction from current_interaction list.""" + if len(self.current_interaction) > 0: + self.current_interaction = self.current_interaction[:-1] + else: + raise ValueError("No interaction to delete.") + + @staticmethod + def normalize_interaction_sequence( + interaction: Optional[Union[str, List[str]]] + ) -> List[str]: + """ + Normalize interaction input to a flat list of strings. + Supports None, single string, or list of strings. + """ + if interaction is None: + return [] + if isinstance(interaction, str): + return [interaction] + return [str(sig) for sig in interaction if str(sig).strip()] + + @staticmethod + def apply_interaction_to_camera( + camera_cfg: Dict[str, Any], + interaction: str, + camera_range: Dict[str, Any], + yaw_step: float = 30.0, + pitch_step: float = 20.0, + zoom_factor: float = 0.6, + ) -> Dict[str, Any]: + """ + Update a simple (radius, yaw, pitch) camera configuration according to a + high-level interaction signal, clamped by camera_range. + Only supports the unified 3D interaction schema + (forward/backward/left/right, forward_left, camera_l, camera_zoom_in, ...). + """ + yaw = float(camera_cfg.get("yaw", 0.0)) + pitch = float(camera_cfg.get("pitch", 0.0)) + radius = float(camera_cfg.get("radius", 4.0)) + sig = interaction.strip().lower() + + # Yaw (left/right) + if sig in ["left", "camera_l"]: + yaw -= yaw_step + elif sig in ["right", "camera_r"]: + yaw += yaw_step + elif sig == "camera_ul": + yaw -= yaw_step + pitch += pitch_step + elif sig == "camera_ur": + yaw += yaw_step + pitch += pitch_step + elif sig == "camera_dl": + yaw -= yaw_step + pitch -= pitch_step + elif sig == "camera_dr": + yaw += yaw_step + pitch -= pitch_step + # Pitch (up/down) + elif sig == "camera_up": + pitch += pitch_step + elif sig == "camera_down": + pitch -= pitch_step + # Radius (forward/backward, zoom) + elif sig in ["forward", "camera_zoom_in"]: + radius *= zoom_factor + elif sig in ["backward", "camera_zoom_out"]: + radius /= zoom_factor + elif sig == "forward_left": + yaw -= yaw_step + radius *= zoom_factor + elif sig == "forward_right": + yaw += yaw_step + radius *= zoom_factor + elif sig == "backward_left": + yaw -= yaw_step + radius /= zoom_factor + elif sig == "backward_right": + yaw += yaw_step + radius /= zoom_factor + + yaw = max(camera_range["yaw_min"], min(camera_range["yaw_max"], yaw)) + pitch = max(camera_range["pitch_min"], min(camera_range["pitch_max"], pitch)) + radius = max(camera_range["radius_min"], min(camera_range["radius_max"], radius)) + + camera_cfg["yaw"] = yaw + camera_cfg["pitch"] = pitch + camera_cfg["radius"] = radius + + return camera_cfg + + diff --git a/src/openworldlib/pipelines/cut3r/pipeline_cut3r.py b/src/openworldlib/pipelines/cut3r/pipeline_cut3r.py index 04431fcc..8dbf83c9 100644 --- a/src/openworldlib/pipelines/cut3r/pipeline_cut3r.py +++ b/src/openworldlib/pipelines/cut3r/pipeline_cut3r.py @@ -415,8 +415,9 @@ def reconstruct_ply( dists = np.linalg.norm(all_points - center[None, :], axis=1) radius = float(dists.max() + 1e-6) - radius_min = max(radius * 0.5, 1e-3) - radius_max = radius * 3.0 + # Allow camera to move noticeably closer/farther during forward/backward interactions. + radius_min = max(radius * 0.2, 1e-3) + radius_max = radius * 4.0 camera_range = { "center": center.tolist(), @@ -440,7 +441,7 @@ def reconstruct_ply( "camera_range": camera_range, "default_camera": default_camera, } - + @staticmethod def _preprocess_point_cloud_for_render( points: np.ndarray, @@ -516,125 +517,19 @@ def render_with_3dgs( far_plane: float = 1000.0, ) -> Image.Image: """ - Stage 2: Render a view from the reconstructed PLY using a 3D Gaussian - Splatting renderer. - - Args: - ply_path: Path to the reconstructed point cloud PLY. - camera_config: Dictionary describing the camera, with keys: - - 'center': list of 3 floats, scene center - - 'radius': float, camera distance to center - - 'yaw': float, yaw angle in degrees (around Y axis) - - 'pitch': float, pitch angle in degrees (around X axis) - image_width: Output image width. - image_height: Output image height. - device: Torch device to use. Defaults to 'cuda' if available, else 'cpu'. - near_plane: Near plane distance for rendering. - far_plane: Far plane distance for rendering. - - Returns: - A PIL.Image with the rendered view. + Thin wrapper: delegate 3DGS rendering to CUT3RRepresentation. """ - from ...base_models.three_dimensions.point_clouds.gaussian_splatting.scene.dataset_readers import ( - fetchPly, - ) - - device = device or ("cuda" if torch.cuda.is_available() else "cpu") - - pcd = fetchPly(ply_path) - points = np.asarray(pcd.points, dtype=np.float32) - colors = np.asarray(pcd.colors, dtype=np.float32) - - if points.size == 0: - raise RuntimeError(f"No points loaded from PLY: {ply_path}") - - center = np.asarray(camera_config.get("center", points.mean(axis=0)), dtype=np.float32) - points, colors = self._preprocess_point_cloud_for_render(points, colors, center) - if points.size == 0: - raise RuntimeError("Point cloud is empty after preprocessing for rendering.") - - radius = float(camera_config.get("radius", 1.5 * np.linalg.norm(points - center[None, :], axis=1).max())) - yaw_deg = float(camera_config.get("yaw", 0.0)) - pitch_deg = float(camera_config.get("pitch", 0.0)) - - yaw = np.deg2rad(yaw_deg) - pitch = np.deg2rad(pitch_deg) - - cam_x = center[0] + radius * np.cos(pitch) * np.sin(yaw) - cam_y = center[1] + radius * np.sin(pitch) - cam_z = center[2] + radius * np.cos(pitch) * np.cos(yaw) - cam_pos = np.array([cam_x, cam_y, cam_z], dtype=np.float32) - - forward = center - cam_pos - forward = forward / (np.linalg.norm(forward) + 1e-8) - up = np.array([0.0, 1.0, 0.0], dtype=np.float32) - - right = np.cross(forward, up) - right = right / (np.linalg.norm(right) + 1e-8) - up = np.cross(right, forward) - up = up / (np.linalg.norm(up) + 1e-8) - - c2w = np.eye(4, dtype=np.float32) - c2w[0, :3] = right - c2w[1, :3] = up - c2w[2, :3] = forward - c2w[:3, 3] = cam_pos - - fx = 0.5 * image_width / np.tan(np.deg2rad(60.0) / 2.0) - fy = 0.5 * image_height / np.tan(np.deg2rad(45.0) / 2.0) - cx = image_width / 2.0 - cy = image_height / 2.0 - - sh_degree = 0 - - xyz = torch.from_numpy(points).to(device=device, dtype=torch.float32) - scale_value = self._estimate_gaussian_scale(points, center) - scale = torch.full((xyz.shape[0], 3), scale_value, device=device, dtype=torch.float32) - - rotation = torch.zeros((xyz.shape[0], 4), device=device, dtype=torch.float32) - rotation[:, 0] = 1.0 - - # Align closer to CUT3R's gsplat usage (high opacity, small gaussian scale). - opacity = torch.full((xyz.shape[0], 1), 0.95, device=device, dtype=torch.float32) - - color_tensor = torch.from_numpy(np.clip(colors, 0.0, 1.0)).to(device=device, dtype=torch.float32) - features = color_tensor - - gaussian_params = torch.cat( - [xyz, opacity, scale, rotation, features], - dim=-1, - ).unsqueeze(0) - - test_c2ws = torch.from_numpy(c2w).unsqueeze(0).unsqueeze(0).to(device=device, dtype=torch.float32) - intr = torch.tensor([[fx, fy, cx, cy]], dtype=torch.float32, device=device).unsqueeze(0) - - rgb, _ = gaussian_render( - gaussian_params, - test_c2ws, - intr, - image_width, - image_height, + if self.representation_model is None: + raise RuntimeError("Representation model not loaded. Use from_pretrained() first.") + return self.representation_model.render_with_3dgs( + ply_path=ply_path, + camera_config=camera_config, + image_width=image_width, + image_height=image_height, + device=device, near_plane=near_plane, far_plane=far_plane, - use_checkpoint=False, - sh_degree=sh_degree, - bg_mode='white', ) - - # gaussian_render returns rgb in shape (B, V, 3, H, W) - # Use the first batch and first view to form an RGB frame (H, W, 3). - rgb_img = rgb[0, 0] - rgb_img = rgb_img.clamp(-1.0, 1.0).add(1.0).div(2.0) - rgb_np = ( - rgb_img.mul(255.0) - .permute(1, 2, 0) - .detach() - .cpu() - .numpy() - .astype(np.uint8) - ) - - return Image.fromarray(rgb_np) def render_orbit_video_with_3dgs( self, @@ -695,46 +590,6 @@ def render_orbit_video_with_3dgs( return frames - @staticmethod - def _apply_interaction_to_camera( - camera_cfg: Dict[str, Any], - interaction: str, - camera_range: Dict[str, Any], - yaw_step: float = 10.0, - pitch_step: float = 7.5, - zoom_factor: float = 0.9, - ) -> Dict[str, Any]: - """ - Update a simple (radius, yaw, pitch) camera configuration according to a - high-level interaction signal, clamped by camera_range. - """ - yaw = float(camera_cfg.get("yaw", 0.0)) - pitch = float(camera_cfg.get("pitch", 0.0)) - radius = float(camera_cfg.get("radius", 4.0)) - - if interaction in ["move_left", "rotate_left"]: - yaw -= yaw_step - elif interaction in ["move_right", "rotate_right"]: - yaw += yaw_step - elif interaction == "move_up": - pitch += pitch_step - elif interaction == "move_down": - pitch -= pitch_step - elif interaction == "zoom_in": - radius *= zoom_factor - elif interaction == "zoom_out": - radius /= zoom_factor - - yaw = max(camera_range["yaw_min"], min(camera_range["yaw_max"], yaw)) - pitch = max(camera_range["pitch_min"], min(camera_range["pitch_max"], pitch)) - radius = max(camera_range["radius_min"], min(camera_range["radius_max"], radius)) - - camera_cfg["yaw"] = yaw - camera_cfg["pitch"] = pitch - camera_cfg["radius"] = radius - - return camera_cfg - def render_interaction_video_with_3dgs( self, ply_path: str, @@ -765,10 +620,10 @@ def render_interaction_video_with_3dgs( } for sig in interaction_sequence: - camera_cfg = self._apply_interaction_to_camera( - camera_cfg, - sig, - camera_range, + camera_cfg = self.operator.apply_interaction_to_camera( + camera_cfg=camera_cfg, + interaction=sig, + camera_range=camera_range, ) img = self.render_with_3dgs( ply_path=ply_path, @@ -786,8 +641,9 @@ def render_interaction_video_with_3dgs( def run_two_stage_3dgs_video( self, - data_path: Union[str, Image.Image, np.ndarray, List[str], List[Image.Image], List[np.ndarray]], - interaction: Optional[Union[str, List[str]]] = None, + image_path: Union[str, Image.Image, np.ndarray, List[str], List[Image.Image], List[np.ndarray]], + interactions: Optional[Union[str, List[str]]] = None, + frames_per_interaction: int = 10, size: Optional[int] = None, vis_threshold: float = 1.5, output_dir: str = "./cut3r_output", @@ -825,7 +681,7 @@ def run_two_stage_3dgs_video( os.makedirs(output_dir, exist_ok=True) recon_info = self.reconstruct_ply( - data_path, + image_path, ply_path=output_dir, size=size, vis_threshold=vis_threshold, @@ -843,12 +699,18 @@ def run_two_stage_3dgs_video( output_video_path = os.path.join(output_dir, output_name) - if isinstance(interaction, list) and len(interaction) > 0: + interaction_sequence = self.operator.normalize_interaction_sequence(interactions) + if interaction_sequence and frames_per_interaction > 1: + interaction_sequence = [ + a for a in interaction_sequence for _ in range(frames_per_interaction) + ] + + if interaction_sequence: self.render_interaction_video_with_3dgs( ply_path=ply_path, camera_range=camera_range, base_camera_config=base_camera_config, - interaction_sequence=interaction, + interaction_sequence=interaction_sequence, image_width=image_width, image_height=image_height, output_path=output_video_path, @@ -866,22 +728,33 @@ def run_two_stage_3dgs_video( def __call__( self, - input_: Union[str, Image.Image, np.ndarray, List[str], List[Image.Image], List[np.ndarray]], - interaction: Optional[Union[str, Dict[str, Any]]] = None, - **kwargs - ) -> CUT3RResult: + image_path: Optional[Union[str, List[str]]] = None, + images: Any = None, + interactions: Optional[Union[str, List[str]]] = None, + task_type: Optional[str] = None, + **kwargs, + ) -> Union[CUT3RResult, str]: """ - Main call interface for the pipeline. - - Args: - input_: Input image(s) - interaction: Interaction string or dictionary - **kwargs: Additional arguments - - Returns: - CUT3RResult object containing processed results as PIL Images or video frame list + Main call interface. + - Base mode (task_type is None or "cut3r_base"): direct CUT3R representation/process. + - "cut3r_two_stage_3dgs": two-stage reconstruction + 3DGS video (returns output_video_path: str). """ - return self.process(input_, interaction, **kwargs) + data = images if images is not None else image_path + if data is None: + raise ValueError("Provide image_path or images.") + + if task_type == "cut3r_two_stage_3dgs": + return self.run_two_stage_3dgs_video( + image_path=data, + interactions=interactions, + **kwargs, + ) + + return self.process( + input_=data, + interaction=interactions, + **kwargs, + ) def stream( self, diff --git a/src/openworldlib/representations/point_clouds_generation/cut3r/cut3r_representation.py b/src/openworldlib/representations/point_clouds_generation/cut3r/cut3r_representation.py index dfc746b0..77bb4dfb 100644 --- a/src/openworldlib/representations/point_clouds_generation/cut3r/cut3r_representation.py +++ b/src/openworldlib/representations/point_clouds_generation/cut3r/cut3r_representation.py @@ -7,6 +7,14 @@ from huggingface_hub import snapshot_download +from ...base_representation import BaseRepresentation +from ....base_models.three_dimensions.point_clouds.gaussian_splatting.scene.dataset_readers import ( + fetchPly, +) +from ....representations.point_clouds_generation.flash_world.flash_world.render import ( + gaussian_render, +) + # Try to import gdown for Google Drive downloads try: import gdown @@ -41,7 +49,7 @@ } -class CUT3RRepresentation: +class CUT3RRepresentation(BaseRepresentation): """ Representation for CUT3R 3D scene reconstruction. """ @@ -54,6 +62,7 @@ def __init__(self, model: Optional[ARCroco3DStereo] = None, device: Optional[str model: Pre-loaded ARCroco3DStereo model (optional) device: Device to run on ('cuda' or 'cpu') """ + super().__init__() self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") self.model = model @@ -233,6 +242,12 @@ def _prepare_views( image_paths.append(temp_file.name) temp_files.append(temp_file.name) + # CUT3R recurrent inference expects at least two views. + # If only one image is provided, duplicate it as a minimal fallback + # to avoid invalid position/index errors inside RoPE attention. + if len(image_paths) == 1: + image_paths.append(image_paths[0]) + # Load images using CUT3R's loader loaded_images = load_images(image_paths, size=self.size, verbose=False) @@ -246,6 +261,18 @@ def _prepare_views( # Convert to views format views = [] for i, img_data in enumerate(loaded_images): + # Ensure true_shape is in the correct format: (batch_size, 2) + # load_images returns (1, 2), we need to expand to (batch_size, 2) + batch_size = img_data["img"].shape[0] + true_shape_np = img_data["true_shape"] # Shape: (1, 2) + true_shape_tensor = torch.from_numpy(true_shape_np) # Shape: (1, 2) + # Expand to (batch_size, 2) if needed + if true_shape_tensor.shape[0] == 1 and batch_size > 1: + true_shape_tensor = true_shape_tensor.repeat(batch_size, 1) + elif true_shape_tensor.shape[0] != batch_size: + # If shape doesn't match, use the first row and repeat + true_shape_tensor = true_shape_tensor[0:1].repeat(batch_size, 1) + view = { "img": img_data["img"], "ray_map": torch.full( @@ -257,7 +284,7 @@ def _prepare_views( ), torch.nan, ), - "true_shape": torch.from_numpy(img_data["true_shape"]), + "true_shape": true_shape_tensor, "idx": i, "instance": str(i), "camera_pose": torch.from_numpy(np.eye(4, dtype=np.float32)).unsqueeze(0), @@ -381,3 +408,315 @@ def get_representation(self, data: Dict[str, Any]) -> Dict[str, Any]: return results + @staticmethod + def _preprocess_point_cloud_for_render( + points: np.ndarray, + colors: np.ndarray, + scene_center: np.ndarray, + ) -> tuple[np.ndarray, np.ndarray]: + """ + Light-weight cleanup to make rendering closer to CUT3R visualization: + 1) remove invalid rows + 2) trim far outliers + 3) voxel downsample to reduce overdraw blur + """ + valid_mask = np.isfinite(points).all(axis=1) & np.isfinite(colors).all(axis=1) + points = points[valid_mask] + colors = colors[valid_mask] + if len(points) == 0: + return points, colors + + # Trim extreme outliers by distance-to-center (keeps dense core). + d = np.linalg.norm(points - scene_center[None, :], axis=1) + d_thr = np.quantile(d, 0.995) + keep = d <= d_thr + points = points[keep] + colors = colors[keep] + if len(points) == 0: + return points, colors + + scene_radius = float(np.linalg.norm(points - scene_center[None, :], axis=1).max() + 1e-8) + voxel_size = max(scene_radius / 512.0, 1e-4) + + # Voxel downsample (first-point per voxel, deterministic). + voxel_coords = np.floor(points / voxel_size).astype(np.int64) + _, unique_idx = np.unique(voxel_coords, axis=0, return_index=True) + unique_idx = np.sort(unique_idx) + points = points[unique_idx] + colors = colors[unique_idx] + + return points, colors + + @staticmethod + def _estimate_gaussian_scale(points: np.ndarray, scene_center: np.ndarray) -> float: + """ + Estimate a conservative Gaussian scale from local spacing. + Large scales are the main reason for "foggy/blurry" outputs. + """ + if len(points) < 4: + scene_radius = float(np.linalg.norm(points - scene_center[None, :], axis=1).max() + 1e-8) + return max(scene_radius / 2000.0, 1e-4) + + sample_n = min(len(points), 2048) + rng = np.random.default_rng(42) + idx = rng.choice(len(points), size=sample_n, replace=False) + sample = torch.from_numpy(points[idx]).float() + # Pairwise distances on a small sample for robust nearest-neighbor spacing. + dist = torch.cdist(sample, sample, p=2) + dist.fill_diagonal_(1e9) + nn = dist.min(dim=1).values + nn_med = float(nn.median().item()) + + scene_radius = float(np.linalg.norm(points - scene_center[None, :], axis=1).max() + 1e-8) + min_scale = max(scene_radius / 5000.0, 1e-4) + max_scale = max(scene_radius / 300.0, min_scale) + return float(np.clip(nn_med * 0.6, min_scale, max_scale)) + + def render_with_3dgs( + self, + ply_path: str, + camera_config: Dict[str, Any], + image_width: int = 704, + image_height: int = 480, + device: Optional[str] = None, + near_plane: float = 0.01, + far_plane: float = 1000.0, + ) -> Image.Image: + """ + Render a single frame from a CUT3R reconstruction using 3D Gaussian Splatting. + Follow the same robust rendering strategy as VGGT: + - normalize scene for scale invariance + - dynamic near/far planes + - multiple camera convention probing + - deterministic point-projection fallback when gsplat fails + """ + device = device or ("cuda" if torch.cuda.is_available() else "cpu") + + pcd = fetchPly(ply_path) + points = np.asarray(pcd.points, dtype=np.float32) + colors = np.asarray(pcd.colors, dtype=np.float32) + + if points.size == 0: + raise RuntimeError(f"No points loaded from PLY: {ply_path}") + + # Estimate global scene scale and normalize to stabilize rendering. + scene_center = points.mean(axis=0) + scene_radius = float(np.linalg.norm(points - scene_center[None, :], axis=1).max() + 1e-8) + scene_radius = max(scene_radius, 1e-6) + + points_norm = (points - scene_center[None, :]) / scene_radius + center = np.asarray(camera_config.get("center", scene_center.tolist()), dtype=np.float32) + center_norm = (center - scene_center) / scene_radius + + radius_raw = float(camera_config.get("radius", 1.0 * scene_radius)) + radius_norm = max(radius_raw / scene_radius, 1e-3) + # +180° yaw so camera is on the opposite side (scene faces camera, not back). + yaw_deg = float(camera_config.get("yaw", 0.0)) + 180.0 + pitch_deg = float(camera_config.get("pitch", 0.0)) + + yaw = np.deg2rad(yaw_deg) + pitch = np.deg2rad(pitch_deg) + cam_x = center_norm[0] + radius_norm * np.cos(pitch) * np.sin(yaw) + cam_y = center_norm[1] + radius_norm * np.sin(pitch) + cam_z = center_norm[2] + radius_norm * np.cos(pitch) * np.cos(yaw) + cam_pos = np.array([cam_x, cam_y, cam_z], dtype=np.float32) + + def build_c2w( + look_at: np.ndarray, + eye: np.ndarray, + reverse_forward: bool = False, + basis_layout: str = "row", + ) -> np.ndarray: + forward = (eye - look_at) if reverse_forward else (look_at - eye) + forward = forward / (np.linalg.norm(forward) + 1e-8) + up = np.array([0.0, 1.0, 0.0], dtype=np.float32) + right = np.cross(forward, up) + right_norm = np.linalg.norm(right) + if right_norm < 1e-6: + up = np.array([0.0, 0.0, 1.0], dtype=np.float32) + right = np.cross(forward, up) + right_norm = np.linalg.norm(right) + right = right / (right_norm + 1e-8) + up = np.cross(right, forward) + up = up / (np.linalg.norm(up) + 1e-8) + + c2w_local = np.eye(4, dtype=np.float32) + if basis_layout == "row": + c2w_local[0, :3] = right + c2w_local[1, :3] = up + c2w_local[2, :3] = forward + else: + c2w_local[:3, 0] = right + c2w_local[:3, 1] = up + c2w_local[:3, 2] = forward + c2w_local[:3, 3] = eye + return c2w_local + + fx = 0.5 * image_width / np.tan(np.deg2rad(60.0) / 2.0) + fy = 0.5 * image_height / np.tan(np.deg2rad(45.0) / 2.0) + cx = image_width / 2.0 + cy = image_height / 2.0 + + xyz = torch.from_numpy(points_norm).to(device=device, dtype=torch.float32) + scale_value = self._estimate_gaussian_scale(points_norm, center_norm) + scale = torch.full((xyz.shape[0], 3), scale_value, device=device, dtype=torch.float32) + rotation = torch.zeros((xyz.shape[0], 4), device=device, dtype=torch.float32) + rotation[:, 0] = 1.0 + opacity = torch.full((xyz.shape[0], 1), 0.95, device=device, dtype=torch.float32) + color_tensor = torch.from_numpy(np.clip(colors, 0.0, 1.0)).to(device=device, dtype=torch.float32) + + gaussian_params = torch.cat([xyz, opacity, scale, rotation, color_tensor], dim=-1).unsqueeze(0) + intr = torch.tensor([[fx, fy, cx, cy]], dtype=torch.float32, device=device).unsqueeze(0) + + # Dynamic planes are more robust for arbitrary world scales. + near_dynamic = max(near_plane, radius_norm * 0.01) + far_dynamic = max(far_plane, radius_norm * 20.0) + + if not hasattr(self, "_render_variant_cache"): + self._render_variant_cache = {} + + def render_candidate(reverse_forward: bool, basis_layout: str): + c2w_local = build_c2w( + look_at=center_norm, + eye=cam_pos, + reverse_forward=reverse_forward, + basis_layout=basis_layout, + ) + test_c2ws_local = torch.from_numpy(c2w_local).unsqueeze(0).unsqueeze(0).to( + device=device, dtype=torch.float32 + ) + rgb_local, _ = gaussian_render( + gaussian_params, + test_c2ws_local, + intr, + image_width, + image_height, + near_plane=near_dynamic, + far_plane=far_dynamic, + use_checkpoint=False, + sh_degree=0, + bg_mode="black", + ) + rgb_img_local = rgb_local[0, 0].clamp(-1.0, 1.0).add(1.0).div(2.0) + gray = rgb_img_local.mean(dim=0) + non_bg_ratio = float((gray > 0.03).float().mean().item()) + std_v = float(rgb_img_local.std().item()) + score = non_bg_ratio + 0.5 * std_v + return rgb_img_local, score, non_bg_ratio + + cached_variant = self._render_variant_cache.get( + ply_path, + {"reverse_forward": False, "basis_layout": "row"}, + ) + rgb_img, best_score, best_non_bg_ratio = render_candidate( + reverse_forward=bool(cached_variant["reverse_forward"]), + basis_layout=str(cached_variant["basis_layout"]), + ) + + # If the cached/default pose is too empty, probe multiple camera conventions once. + if best_score < 0.03 or best_non_bg_ratio < 0.001: + candidates = [ + {"reverse_forward": False, "basis_layout": "row"}, + {"reverse_forward": True, "basis_layout": "row"}, + {"reverse_forward": False, "basis_layout": "col"}, + {"reverse_forward": True, "basis_layout": "col"}, + ] + best_variant = cached_variant + for cand in candidates: + rgb_try, score_try, non_bg_try = render_candidate( + reverse_forward=bool(cand["reverse_forward"]), + basis_layout=str(cand["basis_layout"]), + ) + if score_try > best_score: + rgb_img = rgb_try + best_score = score_try + best_non_bg_ratio = non_bg_try + best_variant = cand + self._render_variant_cache[ply_path] = best_variant + + # If gsplat still fails (near-empty), fallback to deterministic point projection. + if best_score < 0.03 or best_non_bg_ratio < 0.001: + best_variant = self._render_variant_cache.get( + ply_path, + {"reverse_forward": False, "basis_layout": "row"}, + ) + c2w_best = build_c2w( + look_at=center_norm, + eye=cam_pos, + reverse_forward=bool(best_variant["reverse_forward"]), + basis_layout=str(best_variant["basis_layout"]), + ) + + img_fallback = np.zeros((image_height, image_width, 3), dtype=np.float32) + depth_buf = np.full((image_height, image_width), np.inf, dtype=np.float32) + + max_points = 300000 + if points_norm.shape[0] > max_points: + rng = np.random.default_rng(42) + keep_idx = rng.choice(points_norm.shape[0], size=max_points, replace=False) + proj_points = points_norm[keep_idx] + proj_colors = colors[keep_idx] + else: + proj_points = points_norm + proj_colors = colors + + w2c = np.linalg.inv(c2w_best).astype(np.float32) + pts_h = np.concatenate( + [proj_points, np.ones((proj_points.shape[0], 1), dtype=np.float32)], + axis=1, + ) + cam_pts = (w2c @ pts_h.T).T[:, :3] + + best_proj_count = -1 + best_proj_payload = None + for depth_sign in [1.0, -1.0]: + z = cam_pts[:, 2] * depth_sign + valid_z = z > 1e-4 + cam_pts_s = cam_pts[valid_z] + z_s = z[valid_z] + c_s = proj_colors[valid_z] + if cam_pts_s.shape[0] == 0: + continue + + u = (fx * (cam_pts_s[:, 0] / z_s) + cx).astype(np.int32) + v = (fy * (cam_pts_s[:, 1] / z_s) + cy).astype(np.int32) + in_view = (u >= 0) & (u < image_width) & (v >= 0) & (v < image_height) + view_count = int(in_view.sum()) + if view_count > best_proj_count: + best_proj_count = view_count + best_proj_payload = (u[in_view], v[in_view], z_s[in_view], c_s[in_view]) + + if best_proj_payload is not None and best_proj_count > 0: + u, v, z, c_proj = best_proj_payload + order = np.argsort(z) + u = u[order] + v = v[order] + z = z[order] + c_proj = c_proj[order] + + for uu, vv, zz, cc in zip(u, v, z, c_proj): + if zz < depth_buf[vv, uu]: + depth_buf[vv, uu] = zz + img_fallback[vv, uu] = np.clip(cc, 0.0, 1.0) + + valid_mask = np.isfinite(depth_buf).astype(np.uint8) + if valid_mask.any(): + kernel = np.ones((3, 3), np.uint8) + dilated = cv2.dilate((img_fallback * 255).astype(np.uint8), kernel, iterations=1) + filled = cv2.dilate(valid_mask, kernel, iterations=1) + img_fallback[filled > 0] = dilated[filled > 0] / 255.0 + + rgb_img = torch.from_numpy(img_fallback).permute(2, 0, 1).to(torch.float32) + + rgb_np = ( + rgb_img.mul(255.0) + .permute(1, 2, 0) + .detach() + .cpu() + .numpy() + .astype(np.uint8) + ) + # Orientation fix: vertical flip only (correct upside-down). + rgb_np = np.flipud(rgb_np) + return Image.fromarray(rgb_np) + diff --git a/test/test_cut3r.py b/test/test_cut3r.py index 4be139d9..ff670747 100644 --- a/test/test_cut3r.py +++ b/test/test_cut3r.py @@ -6,7 +6,7 @@ from openworldlib.pipelines.cut3r.pipeline_cut3r import CUT3RPipeline -DATA_PATH = "./data/test_case/test_image_seq_case1" +DATA_PATH = "./data/test_case/test_image_case1/ref_image.png" MODEL_NAME = "cut3r_224_linear_4" # or "cut3r_512_dpt_4_64" SIZE = 224 @@ -15,8 +15,8 @@ # Interaction sequence for camera control in the second stage. # Keep None to use a default orbit; or set to a list like: -# ["move_left", "move_right", "zoom_in"]. -INTERACTION = "move_left" +# ["forward", "camera_l"]. +INTERACTIONS = ["forward", "camera_l"] # Two-stage camera config for 3DGS rendering. CAMERA_RADIUS = 4.0 @@ -32,9 +32,10 @@ size=SIZE, ) -output_video_path = pipeline.run_two_stage_3dgs_video( - data_path=DATA_PATH, - interaction=INTERACTION, +output_video_path = pipeline( + image_path=DATA_PATH, + interactions=INTERACTIONS, + task_type="cut3r_two_stage_3dgs", size=SIZE, vis_threshold=VIS_THRESHOLD, output_dir=OUTPUT_DIR,