diff --git a/README.md b/README.md index 27b89ba71..da78091ed 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ +This is a fork created by Tail-19, aiming to make this repository work on M1 chip. The project has finished 🎉 + # NeRF-pytorch diff --git a/requirements.txt b/requirements.txt index 168d1b855..f3e3337e3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -torch==1.11.0 +torch>=1.11.0 torchvision>=0.9.1 imageio imageio-ffmpeg diff --git a/run_nerf.py b/run_nerf.py index bc270be86..c81ac3eb4 100644 --- a/run_nerf.py +++ b/run_nerf.py @@ -19,7 +19,7 @@ from load_LINEMOD import load_LINEMOD_data -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.mps.device_count() > 0 else "cpu") np.random.seed(0) DEBUG = False @@ -275,7 +275,9 @@ def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=F raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists) dists = z_vals[...,1:] - z_vals[...,:-1] - dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[...,:1].shape)], -1) # [N_rays, N_samples] + # print(type(dists)) + # print(dists.device) + dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[...,:1].shape).to("mps")], -1) # [N_rays, N_samples] dists = dists * torch.norm(rays_d[...,None,:], dim=-1) @@ -765,7 +767,7 @@ def train(): img_loss = img2mse(rgb, target_s) trans = extras['raw'][...,-1] loss = img_loss - psnr = mse2psnr(img_loss) + psnr = mse2psnr(img_loss) if 'rgb0' in extras: img_loss0 = img2mse(extras['rgb0'], target_s) @@ -873,6 +875,9 @@ def train(): if __name__=='__main__': - torch.set_default_tensor_type('torch.cuda.FloatTensor') - + # torch.set_default_tensor_type('torch.mps.FloatTensor') + torch.set_default_tensor_type('torch.FloatTensor') + + torch.set_default_dtype(torch.float32) + torch.set_default_device(torch.device("mps")) train() diff --git a/run_nerf_helpers.py b/run_nerf_helpers.py index bc6ee779d..e7cb73836 100644 --- a/run_nerf_helpers.py +++ b/run_nerf_helpers.py @@ -7,7 +7,7 @@ # Misc img2mse = lambda x, y : torch.mean((x - y) ** 2) -mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.])) +mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.]).to(x.device)) to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8)