diff --git a/nerfstudio/models/splatfacto.py b/nerfstudio/models/splatfacto.py index 475cc5c388..359334b25e 100644 --- a/nerfstudio/models/splatfacto.py +++ b/nerfstudio/models/splatfacto.py @@ -205,7 +205,7 @@ def populate_modules(self): # We can have colors without points. and self.seed_points[1].shape[0] > 0 ): - shs = torch.zeros((self.seed_points[1].shape[0], dim_sh, 3)).float().cuda() + shs = torch.zeros((self.seed_points[1].shape[0], dim_sh, 3)).float().to(self.device) if self.config.sh_degree > 0: shs[:, 0, :3] = RGB2SH(self.seed_points[1] / 255) shs[:, 1:, 3:] = 0.0 @@ -532,7 +532,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]: camera_scale_fac = self._get_downscale_factor() camera.rescale_output_resolution(1 / camera_scale_fac) viewmat = get_viewmat(optimized_camera_to_world) - K = camera.get_intrinsics_matrices().cuda() + K = camera.get_intrinsics_matrices().to(self.device) W, H = int(camera.width.item()), int(camera.height.item()) self.last_size = (H, W) camera.rescale_output_resolution(camera_scale_fac) # type: ignore