From ead21bad6ff3086b7f0234c913a3bbae48b8c2b7 Mon Sep 17 00:00:00 2001 From: Chris Murray <59452295+ChrisMOxon@users.noreply.github.com> Date: Tue, 24 Mar 2026 01:18:28 +0000 Subject: [PATCH] fix(splatfacto): replace hardcoded .cuda() with .to(self.device) for MPS support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two lines in splatfacto.py use .cuda() directly instead of the model's device property, causing crashes on non-CUDA systems (Apple Silicon MPS, CPU-only machines) with "Torch not compiled with CUDA enabled". Replace with .to(self.device) which correctly resolves to cuda:N on NVIDIA, mps:0 on Apple Silicon, or cpu as appropriate. Tested on M2 Max (MPS) — model initializes and runs forward pass successfully. CUDA compatibility preserved: .to(self.device) resolves to cuda:N on CUDA systems. Discovered and benchmarked by an autonomous Claude Code agent ("Ralph") optimizing a 3D scanning pipeline for Apple Silicon. Co-Authored-By: Claude Opus 4.6 (1M context) --- nerfstudio/models/splatfacto.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nerfstudio/models/splatfacto.py b/nerfstudio/models/splatfacto.py index 475cc5c388..359334b25e 100644 --- a/nerfstudio/models/splatfacto.py +++ b/nerfstudio/models/splatfacto.py @@ -205,7 +205,7 @@ def populate_modules(self): # We can have colors without points. and self.seed_points[1].shape[0] > 0 ): - shs = torch.zeros((self.seed_points[1].shape[0], dim_sh, 3)).float().cuda() + shs = torch.zeros((self.seed_points[1].shape[0], dim_sh, 3)).float().to(self.device) if self.config.sh_degree > 0: shs[:, 0, :3] = RGB2SH(self.seed_points[1] / 255) shs[:, 1:, 3:] = 0.0 @@ -532,7 +532,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]: camera_scale_fac = self._get_downscale_factor() camera.rescale_output_resolution(1 / camera_scale_fac) viewmat = get_viewmat(optimized_camera_to_world) - K = camera.get_intrinsics_matrices().cuda() + K = camera.get_intrinsics_matrices().to(self.device) W, H = int(camera.width.item()), int(camera.height.item()) self.last_size = (H, W) camera.rescale_output_resolution(camera_scale_fac) # type: ignore