diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 000000000..075d62190 Binary files /dev/null and b/.DS_Store differ diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 000000000..bd3c31c31 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,50 @@ +cmake_minimum_required(VERSION 3.10) +project(nerf_cpp) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +# Find required packages +find_package(Torch REQUIRED) +find_package(OpenCV REQUIRED) +find_package(Eigen3 REQUIRED) +find_package(nlohmann_json REQUIRED) + +# Include directories +include_directories( + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${TORCH_INCLUDE_DIRS} + ${OpenCV_INCLUDE_DIRS} + ${EIGEN3_INCLUDE_DIR} + ${nlohmann_json_INCLUDE_DIRS} +) + +# Add source files +file(GLOB_RECURSE SOURCES + "src/*.cpp" +) + +# Create executable +add_executable(nerf_train examples/train.cpp ${SOURCES}) + +# Link libraries +target_link_libraries(nerf_train + ${TORCH_LIBRARIES} + ${OpenCV_LIBS} + Eigen3::Eigen + nlohmann_json::nlohmann_json +) + +# Set compiler flags +if(MSVC) + target_compile_options(nerf_train PRIVATE /W4) +else() + target_compile_options(nerf_train PRIVATE -Wall -Wextra -Wpedantic) +endif() + +# Set CUDA properties if available +if(TORCH_CUDA_AVAILABLE) + set_target_properties(nerf_train PROPERTIES + CUDA_SEPARABLE_COMPILATION ON + ) +endif() \ No newline at end of file diff --git a/README.md b/README.md index 27b89ba71..a420b5676 100644 --- a/README.md +++ b/README.md @@ -161,3 +161,194 @@ However, if you find this implementation or pre-trained models helpful, please c year={2020} } ``` + +# NeRF C++ Implementation + +This is a C++ implementation of [NeRF](http://www.matthewtancik.com/nerf) (Neural Radiance Fields), a method that achieves state-of-the-art results for synthesizing novel views of complex scenes. + +## Dependencies + +You can install dependencies either through Conda or system package manager. + +### Using Conda (Recommended) + +1. Install Miniconda or Anaconda if you haven't already. + +2. Create and activate a new conda environment: +```bash +# Create new environment +conda create -n nerf-cpp python=3.8 +conda activate nerf-cpp + +# Install dependencies +conda install -c pytorch pytorch torchvision cudatoolkit=11.8 +conda install -c conda-forge opencv eigen nlohmann_json +conda install -c conda-forge cmake ninja +``` + +3. Set environment variables for CMake: +```bash +export CMAKE_PREFIX_PATH=$CONDA_PREFIX +export Torch_DIR=$CONDA_PREFIX/lib/python3.8/site-packages/torch/share/cmake/Torch +``` + +### Using System Package Manager (Alternative) + +#### Ubuntu/Debian + +```bash +# Install system dependencies +sudo apt-get update +sudo apt-get install -y \ + build-essential \ + cmake \ + git \ + libopencv-dev \ + libeigen3-dev \ + nlohmann-json3-dev + +# Install LibTorch (PyTorch C++ API) +wget https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.0%2Bcu118.zip +unzip libtorch-cxx11-abi-shared-with-deps-2.1.0+cu118.zip +sudo mv libtorch /usr/local/ +``` + +#### macOS + +```bash +# Install system dependencies using Homebrew +brew install cmake opencv eigen nlohmann-json + +# Install LibTorch +wget https://download.pytorch.org/libtorch/cpu/libtorch-macos-2.1.0.zip +unzip libtorch-macos-2.1.0.zip +sudo mv libtorch /usr/local/ +``` + +## Building the Project + +1. Clone the repository: +```bash +git clone https://github.com/yourusername/nerf-cpp.git +cd nerf-cpp +``` + +2. 
Create a build directory and build the project: +```bash +mkdir build && cd build + +# If using Conda +cmake -DCMAKE_PREFIX_PATH=$CONDA_PREFIX -DTorch_DIR=$CONDA_PREFIX/lib/python3.8/site-packages/torch/share/cmake/Torch .. + +# If using system packages +cmake .. + +make -j$(nproc) +``` + +## Running the Project + +### Download Example Data + +First, download the example datasets: + +```bash +bash download_example_data.sh +``` + +This will download the `lego` and `fern` datasets to the `data` directory. + +### Training + +To train a NeRF model on the lego dataset: + +```bash +./nerf_train configs/lego.txt +``` + +The training process will: +1. Load the dataset from `data/nerf_synthetic/lego` +2. Train the model for 100,000 iterations +3. Save checkpoints every 1,000 iterations +4. Save the final model as `final_model.pt` + +### Configuration + +The configuration files in the `configs` directory control various aspects of training: + +- `expname`: Name of the experiment +- `basedir`: Directory to save logs and checkpoints +- `datadir`: Directory containing the dataset +- `dataset_type`: Type of dataset ("blender" or "llff") +- `N_samples`: Number of samples per ray +- `N_importance`: Number of importance samples +- `use_viewdirs`: Whether to use view-dependent effects +- `white_bkgd`: Whether to use white background + +## Project Structure + +``` +nerf-cpp/ +├── CMakeLists.txt +├── include/ +│ └── nerf/ +│ ├── model.hpp +│ ├── renderer.hpp +│ └── dataset.hpp +├── src/ +│ ├── model.cpp +│ ├── renderer.cpp +│ └── dataset.cpp +├── examples/ +│ └── train.cpp +├── configs/ +│ ├── lego.txt +│ └── fern.txt +└── data/ + ├── nerf_synthetic/ + └── nerf_llff_data/ +``` + +## Troubleshooting + +### Common Issues + +1. **CUDA not found** + - Make sure you have CUDA installed + - If using Conda, make sure you installed the correct CUDA toolkit version + - Set `TORCH_CUDA_VERSION` in CMake if needed + +2. **OpenCV not found** + - If using Conda, make sure you activated the environment + - If using system packages, install OpenCV development packages + - Set `OpenCV_DIR` in CMake if needed + +3. **Eigen3 not found** + - If using Conda, make sure you activated the environment + - If using system packages, install Eigen3 development packages + - Set `EIGEN3_INCLUDE_DIR` in CMake if needed + +### Memory Usage + +The default configuration uses: +- Batch size: 1024 rays +- Samples per ray: 64 +- Network width: 256 +- Network depth: 8 + +You can adjust these parameters in the configuration files to reduce memory usage if needed. + +## Citation + +If you find this implementation helpful, please consider citing: + +```bibtex +@misc{mildenhall2020nerf, + title={NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis}, + author={Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik and Jonathan T. 
Barron and Ravi Ramamoorthi and Ren Ng}, + year={2020}, + eprint={2003.08934}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +```
diff --git a/configs/lego.txt b/configs/lego.txt index 2852ee318..4fb4bdc3f 100644 --- a/configs/lego.txt +++ b/configs/lego.txt @@ -7,13 +7,23 @@ no_batching = True use_viewdirs = True white_bkgd = True -lrate_decay = 500 +# Key parameters for preserving rendering quality N_samples = 64 N_importance = 128 -N_rand = 1024 +multires = 10 +multires_views = 4 +netdepth = 8 +netwidth = 256 + +# Parameters for optimizing speed +N_rand = 2048 +chunk = 32768 +netchunk = 65536 + +# Learning rate schedule optimization +lrate = 1e-3 +lrate_decay = 250 precrop_iters = 500 precrop_frac = 0.5 - -half_res = True
diff --git a/examples/train.cpp b/examples/train.cpp new file mode 100644 index 000000000..66bb78f6b --- /dev/null +++ b/examples/train.cpp @@ -0,0 +1,109 @@ +#include "nerf/model.hpp" +#include "nerf/renderer.hpp" +#include "nerf/dataset.hpp" +#include <torch/torch.h> +#include <iostream> +#include <string> +#include <memory> + +using namespace nerf; + +int main(int argc, char* argv[]) { + try { + // Parse command line arguments + if (argc < 2) { + std::cerr << "Usage: " << argv[0] << " <config_file>" << std::endl; + return 1; + } + + // Load configuration + std::string config_file = argv[1]; + // TODO: Implement config loading + + // Set device + torch::Device device(torch::kCUDA); + + // Create model + auto model = std::make_shared<NeRFModel>( + 8, // netdepth + 256, // netwidth + 8, // netdepth_fine + 256, // netwidth_fine + 10, // multires + 4, // multires_views + true // use_viewdirs + ); + model->to(device); + + // Create renderer + auto renderer = std::make_shared<Renderer>( + model, + 64, // N_samples + 64, // N_importance + true, // use_viewdirs + 1.0f, // raw_noise_std + true // white_bkgd + ); + + // Create dataset + auto dataset = std::make_shared<Dataset>( + "./data/nerf_synthetic/lego", // datadir + "blender", // dataset_type + 8, // factor + true, // use_viewdirs + true // white_bkgd + ); + + // Create optimizer + torch::optim::Adam optimizer( + model->parameters(), + torch::optim::AdamOptions(1e-3) + ); + + // Training loop + int num_epochs = 100000; + int batch_size = 1024; + + for (int epoch = 0; epoch < num_epochs; ++epoch) { + // Get batch of rays + auto [rays_o, rays_d, target_rgb] = dataset->get_data(); + rays_o = rays_o.to(device); + rays_d = rays_d.to(device); + target_rgb = target_rgb.to(device); + + // Forward pass + auto [rgb_map, depth_map, acc_map, _] = renderer->render_rays( + rays_o, rays_d, rays_d, + dataset->get_near().to(device), + dataset->get_far().to(device) + ); + + // Compute loss + auto loss = torch::mse_loss(rgb_map, target_rgb); + + // Backward pass + optimizer.zero_grad(); + loss.backward(); + optimizer.step(); + + // Print progress + if (epoch % 100 == 0) { + std::cout << "Epoch " << epoch << ", Loss: " << loss.item<float>() << std::endl; + } + + // Save checkpoint + if (epoch % 1000 == 0) { + torch::save(model, "checkpoint_" + std::to_string(epoch) + ".pt"); + } + } + + // Save final model + torch::save(model, "final_model.pt"); + + } catch (const std::exception& e) { + std::cerr << "Error: " << e.what() << std::endl; + return 1; + } + + return 0; +} \ No newline at end of file
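The training driver above leaves `// TODO: Implement config loading` open even though it accepts a config path such as `configs/lego.txt`. Below is a minimal sketch of how that key = value format could be parsed; `load_config`, `trim`, and their placement are illustrative assumptions, not part of this PR.

```cpp
// Hypothetical helper -- not part of this PR. Parses "key = value" lines,
// skipping blank lines and '#' comments, as used in configs/lego.txt.
#include <fstream>
#include <map>
#include <stdexcept>
#include <string>

// Strip leading and trailing whitespace from a token.
static std::string trim(const std::string& s) {
    const auto first = s.find_first_not_of(" \t\r\n");
    if (first == std::string::npos) return "";
    const auto last = s.find_last_not_of(" \t\r\n");
    return s.substr(first, last - first + 1);
}

// Read a config file into a string-to-string map; values stay unparsed.
static std::map<std::string, std::string> load_config(const std::string& path) {
    std::ifstream in(path);
    if (!in) throw std::runtime_error("Cannot open config file: " + path);
    std::map<std::string, std::string> cfg;
    std::string line;
    while (std::getline(in, line)) {
        line = trim(line);
        if (line.empty() || line[0] == '#') continue;  // comment or blank line
        const auto eq = line.find('=');
        if (eq == std::string::npos) continue;         // ignore malformed lines
        cfg[trim(line.substr(0, eq))] = trim(line.substr(eq + 1));
    }
    return cfg;
}
```

A caller could then do, for example, `auto cfg = load_config(config_file);` and pull typed values with `std::stoi(cfg.at("N_samples"))` or `cfg.at("use_viewdirs") == "True"` before constructing the model and renderer.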
diff --git a/include/nerf/dataset.hpp b/include/nerf/dataset.hpp new file mode 100644 index 000000000..11dfe6543 --- /dev/null +++ b/include/nerf/dataset.hpp @@ -0,0 +1,48 @@ +#pragma once + +#include <torch/torch.h> +#include <string> +#include <tuple> +#include <vector> + +namespace nerf { + +class Dataset { +public: + Dataset(const std::string& datadir, + const std::string& dataset_type, + int factor = 8, + bool use_viewdirs = true, + bool white_bkgd = false); + + std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> + get_data(); + + int get_H() const { return H_; } + int get_W() const { return W_; } + int get_K() const { return K_; } + float get_focal() const { return focal_; } + torch::Tensor get_near() const { return near_; } + torch::Tensor get_far() const { return far_; } + +private: + std::string datadir_; + std::string dataset_type_; + int factor_; + bool use_viewdirs_; + bool white_bkgd_; + + int H_; + int W_; + int K_; + float focal_; + torch::Tensor near_; + torch::Tensor far_; + + void load_llff_data(); + void load_blender_data(); + torch::Tensor load_image(const std::string& path); + torch::Tensor load_poses(const std::string& path); +}; + +} // namespace nerf \ No newline at end of file
diff --git a/include/nerf/model.hpp b/include/nerf/model.hpp new file mode 100644 index 000000000..0eed45ba6 --- /dev/null +++ b/include/nerf/model.hpp @@ -0,0 +1,42 @@ +#pragma once + +#include <torch/torch.h> +#include <tuple> +#include <vector> + +namespace nerf { + +class NeRFModel : public torch::nn::Module { +public: + NeRFModel(int netdepth = 8, int netwidth = 256, int netdepth_fine = 8, int netwidth_fine = 256, + int multires = 10, int multires_views = 4, bool use_viewdirs = true); + + std::tuple<torch::Tensor, torch::Tensor> forward( + const torch::Tensor& inputs_flat, + const torch::Tensor& viewdirs = torch::Tensor(), + bool is_fine = false); + + std::tuple<torch::Tensor, torch::Tensor> get_outputs( + const torch::Tensor& inputs_flat, + const torch::Tensor& viewdirs = torch::Tensor(), + bool is_fine = false); + +private: + int netdepth_; + int netwidth_; + int netdepth_fine_; + int netwidth_fine_; + int multires_; + int multires_views_; + bool use_viewdirs_; + + torch::nn::Sequential net_{nullptr}; + torch::nn::Sequential net_fine_{nullptr}; + torch::nn::Linear viewdirs_net_{nullptr}; + torch::nn::Linear viewdirs_net_fine_{nullptr}; + + torch::Tensor embed_fn(const torch::Tensor& inputs); + torch::Tensor embeddirs_fn(const torch::Tensor& inputs); +}; + +} // namespace nerf \ No newline at end of file
diff --git a/include/nerf/renderer.hpp b/include/nerf/renderer.hpp new file mode 100644 index 000000000..197974483 --- /dev/null +++ b/include/nerf/renderer.hpp @@ -0,0 +1,48 @@ +#pragma once + +#include "model.hpp" +#include <torch/torch.h> +#include <memory> +#include <tuple> + +namespace nerf { + +class Renderer { +public: + Renderer(std::shared_ptr<NeRFModel> model, + int N_samples = 64, + int N_importance = 64, + bool use_viewdirs = true, + float raw_noise_std = 0.0f, + bool white_bkgd = false); + + std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> render_rays( + const torch::Tensor& rays_o, + const torch::Tensor& rays_d, + const torch::Tensor& viewdirs, + const torch::Tensor& near, + const torch::Tensor& far, + bool is_fine = false); + + torch::Tensor render( + const torch::Tensor& H, + const torch::Tensor& W, + const torch::Tensor& K, + const torch::Tensor& c2w, + const torch::Tensor& near, + const torch::Tensor& far, + bool is_fine = false); + +private: + std::shared_ptr<NeRFModel> model_; + int N_samples_; + int N_importance_; + bool use_viewdirs_; + float raw_noise_std_; + bool white_bkgd_; + + torch::Tensor sample_pdf(const torch::Tensor& bins, const torch::Tensor& weights, int N_samples); + torch::Tensor compute_accumulated_transmittance(const torch::Tensor& alphas); +}; + +} // namespace nerf \ No newline at end of file
diff --git a/load_LINEMOD.py b/load_LINEMOD.py deleted file mode 100644 index 388fdbbc4..000000000 --- a/load_LINEMOD.py +++ /dev/null @@ -1,95 +0,0 @@ -import os -import torch -import numpy as np -import imageio -import json -import torch.nn.functional as F -import cv2 - - -trans_t = lambda t : torch.Tensor([ - [1,0,0,0], - [0,1,0,0], -
[0,0,1,t], - [0,0,0,1]]).float() - -rot_phi = lambda phi : torch.Tensor([ - [1,0,0,0], - [0,np.cos(phi),-np.sin(phi),0], - [0,np.sin(phi), np.cos(phi),0], - [0,0,0,1]]).float() - -rot_theta = lambda th : torch.Tensor([ - [np.cos(th),0,-np.sin(th),0], - [0,1,0,0], - [np.sin(th),0, np.cos(th),0], - [0,0,0,1]]).float() - - -def pose_spherical(theta, phi, radius): - c2w = trans_t(radius) - c2w = rot_phi(phi/180.*np.pi) @ c2w - c2w = rot_theta(theta/180.*np.pi) @ c2w - c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w - return c2w - - -def load_LINEMOD_data(basedir, half_res=False, testskip=1): - splits = ['train', 'val', 'test'] - metas = {} - for s in splits: - with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp: - metas[s] = json.load(fp) - - all_imgs = [] - all_poses = [] - counts = [0] - for s in splits: - meta = metas[s] - imgs = [] - poses = [] - if s=='train' or testskip==0: - skip = 1 - else: - skip = testskip - - for idx_test, frame in enumerate(meta['frames'][::skip]): - fname = frame['file_path'] - if s == 'test': - print(f"{idx_test}th test frame: {fname}") - imgs.append(imageio.imread(fname)) - poses.append(np.array(frame['transform_matrix'])) - imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA) - poses = np.array(poses).astype(np.float32) - counts.append(counts[-1] + imgs.shape[0]) - all_imgs.append(imgs) - all_poses.append(poses) - - i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)] - - imgs = np.concatenate(all_imgs, 0) - poses = np.concatenate(all_poses, 0) - - H, W = imgs[0].shape[:2] - focal = float(meta['frames'][0]['intrinsic_matrix'][0][0]) - K = meta['frames'][0]['intrinsic_matrix'] - print(f"Focal: {focal}") - - render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0) - - if half_res: - H = H//2 - W = W//2 - focal = focal/2. 
- - imgs_half_res = np.zeros((imgs.shape[0], H, W, 3)) - for i, img in enumerate(imgs): - imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA) - imgs = imgs_half_res - # imgs = tf.image.resize_area(imgs, [400, 400]).numpy() - - near = np.floor(min(metas['train']['near'], metas['test']['near'])) - far = np.ceil(max(metas['train']['far'], metas['test']['far'])) - return imgs, poses, render_poses, [H, W, focal], K, i_split, near, far - - diff --git a/load_blender.py b/load_blender.py deleted file mode 100644 index 99daf8f1a..000000000 --- a/load_blender.py +++ /dev/null @@ -1,91 +0,0 @@ -import os -import torch -import numpy as np -import imageio -import json -import torch.nn.functional as F -import cv2 - - -trans_t = lambda t : torch.Tensor([ - [1,0,0,0], - [0,1,0,0], - [0,0,1,t], - [0,0,0,1]]).float() - -rot_phi = lambda phi : torch.Tensor([ - [1,0,0,0], - [0,np.cos(phi),-np.sin(phi),0], - [0,np.sin(phi), np.cos(phi),0], - [0,0,0,1]]).float() - -rot_theta = lambda th : torch.Tensor([ - [np.cos(th),0,-np.sin(th),0], - [0,1,0,0], - [np.sin(th),0, np.cos(th),0], - [0,0,0,1]]).float() - - -def pose_spherical(theta, phi, radius): - c2w = trans_t(radius) - c2w = rot_phi(phi/180.*np.pi) @ c2w - c2w = rot_theta(theta/180.*np.pi) @ c2w - c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w - return c2w - - -def load_blender_data(basedir, half_res=False, testskip=1): - splits = ['train', 'val', 'test'] - metas = {} - for s in splits: - with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp: - metas[s] = json.load(fp) - - all_imgs = [] - all_poses = [] - counts = [0] - for s in splits: - meta = metas[s] - imgs = [] - poses = [] - if s=='train' or testskip==0: - skip = 1 - else: - skip = testskip - - for frame in meta['frames'][::skip]: - fname = os.path.join(basedir, frame['file_path'] + '.png') - imgs.append(imageio.imread(fname)) - poses.append(np.array(frame['transform_matrix'])) - imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA) - poses = np.array(poses).astype(np.float32) - counts.append(counts[-1] + imgs.shape[0]) - all_imgs.append(imgs) - all_poses.append(poses) - - i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)] - - imgs = np.concatenate(all_imgs, 0) - poses = np.concatenate(all_poses, 0) - - H, W = imgs[0].shape[:2] - camera_angle_x = float(meta['camera_angle_x']) - focal = .5 * W / np.tan(.5 * camera_angle_x) - - render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0) - - if half_res: - H = H//2 - W = W//2 - focal = focal/2. 
- - imgs_half_res = np.zeros((imgs.shape[0], H, W, 4)) - for i, img in enumerate(imgs): - imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA) - imgs = imgs_half_res - # imgs = tf.image.resize_area(imgs, [400, 400]).numpy() - - - return imgs, poses, render_poses, [H, W, focal], i_split - - diff --git a/load_deepvoxels.py b/load_deepvoxels.py deleted file mode 100644 index deb2a9c51..000000000 --- a/load_deepvoxels.py +++ /dev/null @@ -1,110 +0,0 @@ -import os -import numpy as np -import imageio - - -def load_dv_data(scene='cube', basedir='/data/deepvoxels', testskip=8): - - - def parse_intrinsics(filepath, trgt_sidelength, invert_y=False): - # Get camera intrinsics - with open(filepath, 'r') as file: - f, cx, cy = list(map(float, file.readline().split()))[:3] - grid_barycenter = np.array(list(map(float, file.readline().split()))) - near_plane = float(file.readline()) - scale = float(file.readline()) - height, width = map(float, file.readline().split()) - - try: - world2cam_poses = int(file.readline()) - except ValueError: - world2cam_poses = None - - if world2cam_poses is None: - world2cam_poses = False - - world2cam_poses = bool(world2cam_poses) - - print(cx,cy,f,height,width) - - cx = cx / width * trgt_sidelength - cy = cy / height * trgt_sidelength - f = trgt_sidelength / height * f - - fx = f - if invert_y: - fy = -f - else: - fy = f - - # Build the intrinsic matrices - full_intrinsic = np.array([[fx, 0., cx, 0.], - [0., fy, cy, 0], - [0., 0, 1, 0], - [0, 0, 0, 1]]) - - return full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses - - - def load_pose(filename): - assert os.path.isfile(filename) - nums = open(filename).read().split() - return np.array([float(x) for x in nums]).reshape([4,4]).astype(np.float32) - - - H = 512 - W = 512 - deepvoxels_base = '{}/train/{}/'.format(basedir, scene) - - full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses = parse_intrinsics(os.path.join(deepvoxels_base, 'intrinsics.txt'), H) - print(full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses) - focal = full_intrinsic[0,0] - print(H, W, focal) - - - def dir2poses(posedir): - poses = np.stack([load_pose(os.path.join(posedir, f)) for f in sorted(os.listdir(posedir)) if f.endswith('txt')], 0) - transf = np.array([ - [1,0,0,0], - [0,-1,0,0], - [0,0,-1,0], - [0,0,0,1.], - ]) - poses = poses @ transf - poses = poses[:,:3,:4].astype(np.float32) - return poses - - posedir = os.path.join(deepvoxels_base, 'pose') - poses = dir2poses(posedir) - testposes = dir2poses('{}/test/{}/pose'.format(basedir, scene)) - testposes = testposes[::testskip] - valposes = dir2poses('{}/validation/{}/pose'.format(basedir, scene)) - valposes = valposes[::testskip] - - imgfiles = [f for f in sorted(os.listdir(os.path.join(deepvoxels_base, 'rgb'))) if f.endswith('png')] - imgs = np.stack([imageio.imread(os.path.join(deepvoxels_base, 'rgb', f))/255. for f in imgfiles], 0).astype(np.float32) - - - testimgd = '{}/test/{}/rgb'.format(basedir, scene) - imgfiles = [f for f in sorted(os.listdir(testimgd)) if f.endswith('png')] - testimgs = np.stack([imageio.imread(os.path.join(testimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32) - - valimgd = '{}/validation/{}/rgb'.format(basedir, scene) - imgfiles = [f for f in sorted(os.listdir(valimgd)) if f.endswith('png')] - valimgs = np.stack([imageio.imread(os.path.join(valimgd, f))/255. 
for f in imgfiles[::testskip]], 0).astype(np.float32) - - all_imgs = [imgs, valimgs, testimgs] - counts = [0] + [x.shape[0] for x in all_imgs] - counts = np.cumsum(counts) - i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)] - - imgs = np.concatenate(all_imgs, 0) - poses = np.concatenate([poses, valposes, testposes], 0) - - render_poses = testposes - - print(poses.shape, imgs.shape) - - return imgs, poses, render_poses, [H,W,focal], i_split - - diff --git a/load_llff.py b/load_llff.py deleted file mode 100644 index 98b791637..000000000 --- a/load_llff.py +++ /dev/null @@ -1,319 +0,0 @@ -import numpy as np -import os, imageio - - -########## Slightly modified version of LLFF data loading code -########## see https://github.com/Fyusion/LLFF for original - -def _minify(basedir, factors=[], resolutions=[]): - needtoload = False - for r in factors: - imgdir = os.path.join(basedir, 'images_{}'.format(r)) - if not os.path.exists(imgdir): - needtoload = True - for r in resolutions: - imgdir = os.path.join(basedir, 'images_{}x{}'.format(r[1], r[0])) - if not os.path.exists(imgdir): - needtoload = True - if not needtoload: - return - - from shutil import copy - from subprocess import check_output - - imgdir = os.path.join(basedir, 'images') - imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))] - imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG']])] - imgdir_orig = imgdir - - wd = os.getcwd() - - for r in factors + resolutions: - if isinstance(r, int): - name = 'images_{}'.format(r) - resizearg = '{}%'.format(100./r) - else: - name = 'images_{}x{}'.format(r[1], r[0]) - resizearg = '{}x{}'.format(r[1], r[0]) - imgdir = os.path.join(basedir, name) - if os.path.exists(imgdir): - continue - - print('Minifying', r, basedir) - - os.makedirs(imgdir) - check_output('cp {}/* {}'.format(imgdir_orig, imgdir), shell=True) - - ext = imgs[0].split('.')[-1] - args = ' '.join(['mogrify', '-resize', resizearg, '-format', 'png', '*.{}'.format(ext)]) - print(args) - os.chdir(imgdir) - check_output(args, shell=True) - os.chdir(wd) - - if ext != 'png': - check_output('rm {}/*.{}'.format(imgdir, ext), shell=True) - print('Removed duplicates') - print('Done') - - - - -def _load_data(basedir, factor=None, width=None, height=None, load_imgs=True): - - poses_arr = np.load(os.path.join(basedir, 'poses_bounds.npy')) - poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1,2,0]) - bds = poses_arr[:, -2:].transpose([1,0]) - - img0 = [os.path.join(basedir, 'images', f) for f in sorted(os.listdir(os.path.join(basedir, 'images'))) \ - if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')][0] - sh = imageio.imread(img0).shape - - sfx = '' - - if factor is not None: - sfx = '_{}'.format(factor) - _minify(basedir, factors=[factor]) - factor = factor - elif height is not None: - factor = sh[0] / float(height) - width = int(sh[1] / factor) - _minify(basedir, resolutions=[[height, width]]) - sfx = '_{}x{}'.format(width, height) - elif width is not None: - factor = sh[1] / float(width) - height = int(sh[0] / factor) - _minify(basedir, resolutions=[[height, width]]) - sfx = '_{}x{}'.format(width, height) - else: - factor = 1 - - imgdir = os.path.join(basedir, 'images' + sfx) - if not os.path.exists(imgdir): - print( imgdir, 'does not exist, returning' ) - return - - imgfiles = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir)) if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')] - if poses.shape[-1] != len(imgfiles): - print( 'Mismatch 
between imgs {} and poses {} !!!!'.format(len(imgfiles), poses.shape[-1]) ) - return - - sh = imageio.imread(imgfiles[0]).shape - poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1]) - poses[2, 4, :] = poses[2, 4, :] * 1./factor - - if not load_imgs: - return poses, bds - - def imread(f): - if f.endswith('png'): - return imageio.imread(f, ignoregamma=True) - else: - return imageio.imread(f) - - imgs = imgs = [imread(f)[...,:3]/255. for f in imgfiles] - imgs = np.stack(imgs, -1) - - print('Loaded image data', imgs.shape, poses[:,-1,0]) - return poses, bds, imgs - - - - - - -def normalize(x): - return x / np.linalg.norm(x) - -def viewmatrix(z, up, pos): - vec2 = normalize(z) - vec1_avg = up - vec0 = normalize(np.cross(vec1_avg, vec2)) - vec1 = normalize(np.cross(vec2, vec0)) - m = np.stack([vec0, vec1, vec2, pos], 1) - return m - -def ptstocam(pts, c2w): - tt = np.matmul(c2w[:3,:3].T, (pts-c2w[:3,3])[...,np.newaxis])[...,0] - return tt - -def poses_avg(poses): - - hwf = poses[0, :3, -1:] - - center = poses[:, :3, 3].mean(0) - vec2 = normalize(poses[:, :3, 2].sum(0)) - up = poses[:, :3, 1].sum(0) - c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1) - - return c2w - - - -def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N): - render_poses = [] - rads = np.array(list(rads) + [1.]) - hwf = c2w[:,4:5] - - for theta in np.linspace(0., 2. * np.pi * rots, N+1)[:-1]: - c = np.dot(c2w[:3,:4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta*zrate), 1.]) * rads) - z = normalize(c - np.dot(c2w[:3,:4], np.array([0,0,-focal, 1.]))) - render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1)) - return render_poses - - - -def recenter_poses(poses): - - poses_ = poses+0 - bottom = np.reshape([0,0,0,1.], [1,4]) - c2w = poses_avg(poses) - c2w = np.concatenate([c2w[:3,:4], bottom], -2) - bottom = np.tile(np.reshape(bottom, [1,1,4]), [poses.shape[0],1,1]) - poses = np.concatenate([poses[:,:3,:4], bottom], -2) - - poses = np.linalg.inv(c2w) @ poses - poses_[:,:3,:4] = poses[:,:3,:4] - poses = poses_ - return poses - - -##################### - - -def spherify_poses(poses, bds): - - p34_to_44 = lambda p : np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1,:], [1,1,4]), [p.shape[0], 1,1])], 1) - - rays_d = poses[:,:3,2:3] - rays_o = poses[:,:3,3:4] - - def min_line_dist(rays_o, rays_d): - A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0,2,1]) - b_i = -A_i @ rays_o - pt_mindist = np.squeeze(-np.linalg.inv((np.transpose(A_i, [0,2,1]) @ A_i).mean(0)) @ (b_i).mean(0)) - return pt_mindist - - pt_mindist = min_line_dist(rays_o, rays_d) - - center = pt_mindist - up = (poses[:,:3,3] - center).mean(0) - - vec0 = normalize(up) - vec1 = normalize(np.cross([.1,.2,.3], vec0)) - vec2 = normalize(np.cross(vec0, vec1)) - pos = center - c2w = np.stack([vec1, vec2, vec0, pos], 1) - - poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:,:3,:4]) - - rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:,:3,3]), -1))) - - sc = 1./rad - poses_reset[:,:3,3] *= sc - bds *= sc - rad *= sc - - centroid = np.mean(poses_reset[:,:3,3], 0) - zh = centroid[2] - radcircle = np.sqrt(rad**2-zh**2) - new_poses = [] - - for th in np.linspace(0.,2.*np.pi, 120): - - camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh]) - up = np.array([0,0,-1.]) - - vec2 = normalize(camorigin) - vec0 = normalize(np.cross(vec2, up)) - vec1 = normalize(np.cross(vec2, vec0)) - pos = camorigin - p = np.stack([vec0, vec1, vec2, pos], 1) - - new_poses.append(p) - - new_poses = np.stack(new_poses, 0) 
- - new_poses = np.concatenate([new_poses, np.broadcast_to(poses[0,:3,-1:], new_poses[:,:3,-1:].shape)], -1) - poses_reset = np.concatenate([poses_reset[:,:3,:4], np.broadcast_to(poses[0,:3,-1:], poses_reset[:,:3,-1:].shape)], -1) - - return poses_reset, new_poses, bds - - -def load_llff_data(basedir, factor=8, recenter=True, bd_factor=.75, spherify=False, path_zflat=False): - - - poses, bds, imgs = _load_data(basedir, factor=factor) # factor=8 downsamples original imgs by 8x - print('Loaded', basedir, bds.min(), bds.max()) - - # Correct rotation matrix ordering and move variable dim to axis 0 - poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1) - poses = np.moveaxis(poses, -1, 0).astype(np.float32) - imgs = np.moveaxis(imgs, -1, 0).astype(np.float32) - images = imgs - bds = np.moveaxis(bds, -1, 0).astype(np.float32) - - # Rescale if bd_factor is provided - sc = 1. if bd_factor is None else 1./(bds.min() * bd_factor) - poses[:,:3,3] *= sc - bds *= sc - - if recenter: - poses = recenter_poses(poses) - - if spherify: - poses, render_poses, bds = spherify_poses(poses, bds) - - else: - - c2w = poses_avg(poses) - print('recentered', c2w.shape) - print(c2w[:3,:4]) - - ## Get spiral - # Get average pose - up = normalize(poses[:, :3, 1].sum(0)) - - # Find a reasonable "focus depth" for this dataset - close_depth, inf_depth = bds.min()*.9, bds.max()*5. - dt = .75 - mean_dz = 1./(((1.-dt)/close_depth + dt/inf_depth)) - focal = mean_dz - - # Get radii for spiral path - shrink_factor = .8 - zdelta = close_depth * .2 - tt = poses[:,:3,3] # ptstocam(poses[:3,3,:].T, c2w).T - rads = np.percentile(np.abs(tt), 90, 0) - c2w_path = c2w - N_views = 120 - N_rots = 2 - if path_zflat: -# zloc = np.percentile(tt, 10, 0)[2] - zloc = -close_depth * .1 - c2w_path[:3,3] = c2w_path[:3,3] + zloc * c2w_path[:3,2] - rads[2] = 0. - N_rots = 1 - N_views/=2 - - # Generate poses for spiral path - render_poses = render_path_spiral(c2w_path, up, rads, focal, zdelta, zrate=.5, rots=N_rots, N=N_views) - - - render_poses = np.array(render_poses).astype(np.float32) - - c2w = poses_avg(poses) - print('Data:') - print(poses.shape, images.shape, bds.shape) - - dists = np.sum(np.square(c2w[:3,3] - poses[:,:3,3]), -1) - i_test = np.argmin(dists) - print('HOLDOUT view is', i_test) - - images = images.astype(np.float32) - poses = poses.astype(np.float32) - - return images, poses, bds, render_poses, i_test - - - diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 168d1b855..000000000 --- a/requirements.txt +++ /dev/null @@ -1,9 +0,0 @@ -torch==1.11.0 -torchvision>=0.9.1 -imageio -imageio-ffmpeg -matplotlib -configargparse -tensorboard>=2.0 -tqdm -opencv-python diff --git a/run_nerf.py b/run_nerf.py deleted file mode 100644 index bc270be86..000000000 --- a/run_nerf.py +++ /dev/null @@ -1,878 +0,0 @@ -import os, sys -import numpy as np -import imageio -import json -import random -import time -import torch -import torch.nn as nn -import torch.nn.functional as F -from tqdm import tqdm, trange - -import matplotlib.pyplot as plt - -from run_nerf_helpers import * - -from load_llff import load_llff_data -from load_deepvoxels import load_dv_data -from load_blender import load_blender_data -from load_LINEMOD import load_LINEMOD_data - - -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -np.random.seed(0) -DEBUG = False - - -def batchify(fn, chunk): - """Constructs a version of 'fn' that applies to smaller batches. 
- """ - if chunk is None: - return fn - def ret(inputs): - return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0) - return ret - - -def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64): - """Prepares inputs and applies network 'fn'. - """ - inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]]) - embedded = embed_fn(inputs_flat) - - if viewdirs is not None: - input_dirs = viewdirs[:,None].expand(inputs.shape) - input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]]) - embedded_dirs = embeddirs_fn(input_dirs_flat) - embedded = torch.cat([embedded, embedded_dirs], -1) - - outputs_flat = batchify(fn, netchunk)(embedded) - outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]]) - return outputs - - -def batchify_rays(rays_flat, chunk=1024*32, **kwargs): - """Render rays in smaller minibatches to avoid OOM. - """ - all_ret = {} - for i in range(0, rays_flat.shape[0], chunk): - ret = render_rays(rays_flat[i:i+chunk], **kwargs) - for k in ret: - if k not in all_ret: - all_ret[k] = [] - all_ret[k].append(ret[k]) - - all_ret = {k : torch.cat(all_ret[k], 0) for k in all_ret} - return all_ret - - -def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True, - near=0., far=1., - use_viewdirs=False, c2w_staticcam=None, - **kwargs): - """Render rays - Args: - H: int. Height of image in pixels. - W: int. Width of image in pixels. - focal: float. Focal length of pinhole camera. - chunk: int. Maximum number of rays to process simultaneously. Used to - control maximum memory usage. Does not affect final results. - rays: array of shape [2, batch_size, 3]. Ray origin and direction for - each example in batch. - c2w: array of shape [3, 4]. Camera-to-world transformation matrix. - ndc: bool. If True, represent ray origin, direction in NDC coordinates. - near: float or array of shape [batch_size]. Nearest distance for a ray. - far: float or array of shape [batch_size]. Farthest distance for a ray. - use_viewdirs: bool. If True, use viewing direction of a point in space in model. - c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for - camera while using other c2w argument for viewing directions. - Returns: - rgb_map: [batch_size, 3]. Predicted RGB values for rays. - disp_map: [batch_size]. Disparity map. Inverse of depth. - acc_map: [batch_size]. Accumulated opacity (alpha) along a ray. - extras: dict with everything returned by render_rays(). 
- """ - if c2w is not None: - # special case to render full image - rays_o, rays_d = get_rays(H, W, K, c2w) - else: - # use provided ray batch - rays_o, rays_d = rays - - if use_viewdirs: - # provide ray directions as input - viewdirs = rays_d - if c2w_staticcam is not None: - # special case to visualize effect of viewdirs - rays_o, rays_d = get_rays(H, W, K, c2w_staticcam) - viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True) - viewdirs = torch.reshape(viewdirs, [-1,3]).float() - - sh = rays_d.shape # [..., 3] - if ndc: - # for forward facing scenes - rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d) - - # Create ray batch - rays_o = torch.reshape(rays_o, [-1,3]).float() - rays_d = torch.reshape(rays_d, [-1,3]).float() - - near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1]) - rays = torch.cat([rays_o, rays_d, near, far], -1) - if use_viewdirs: - rays = torch.cat([rays, viewdirs], -1) - - # Render and reshape - all_ret = batchify_rays(rays, chunk, **kwargs) - for k in all_ret: - k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:]) - all_ret[k] = torch.reshape(all_ret[k], k_sh) - - k_extract = ['rgb_map', 'disp_map', 'acc_map'] - ret_list = [all_ret[k] for k in k_extract] - ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract} - return ret_list + [ret_dict] - - -def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0): - - H, W, focal = hwf - - if render_factor!=0: - # Render downsampled for speed - H = H//render_factor - W = W//render_factor - focal = focal/render_factor - - rgbs = [] - disps = [] - - t = time.time() - for i, c2w in enumerate(tqdm(render_poses)): - print(i, time.time() - t) - t = time.time() - rgb, disp, acc, _ = render(H, W, K, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs) - rgbs.append(rgb.cpu().numpy()) - disps.append(disp.cpu().numpy()) - if i==0: - print(rgb.shape, disp.shape) - - """ - if gt_imgs is not None and render_factor==0: - p = -10. * np.log10(np.mean(np.square(rgb.cpu().numpy() - gt_imgs[i]))) - print(p) - """ - - if savedir is not None: - rgb8 = to8b(rgbs[-1]) - filename = os.path.join(savedir, '{:03d}.png'.format(i)) - imageio.imwrite(filename, rgb8) - - - rgbs = np.stack(rgbs, 0) - disps = np.stack(disps, 0) - - return rgbs, disps - - -def create_nerf(args): - """Instantiate NeRF's MLP model. 
- """ - embed_fn, input_ch = get_embedder(args.multires, args.i_embed) - - input_ch_views = 0 - embeddirs_fn = None - if args.use_viewdirs: - embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed) - output_ch = 5 if args.N_importance > 0 else 4 - skips = [4] - model = NeRF(D=args.netdepth, W=args.netwidth, - input_ch=input_ch, output_ch=output_ch, skips=skips, - input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device) - grad_vars = list(model.parameters()) - - model_fine = None - if args.N_importance > 0: - model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine, - input_ch=input_ch, output_ch=output_ch, skips=skips, - input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device) - grad_vars += list(model_fine.parameters()) - - network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn, - embed_fn=embed_fn, - embeddirs_fn=embeddirs_fn, - netchunk=args.netchunk) - - # Create optimizer - optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999)) - - start = 0 - basedir = args.basedir - expname = args.expname - - ########################## - - # Load checkpoints - if args.ft_path is not None and args.ft_path!='None': - ckpts = [args.ft_path] - else: - ckpts = [os.path.join(basedir, expname, f) for f in sorted(os.listdir(os.path.join(basedir, expname))) if 'tar' in f] - - print('Found ckpts', ckpts) - if len(ckpts) > 0 and not args.no_reload: - ckpt_path = ckpts[-1] - print('Reloading from', ckpt_path) - ckpt = torch.load(ckpt_path) - - start = ckpt['global_step'] - optimizer.load_state_dict(ckpt['optimizer_state_dict']) - - # Load model - model.load_state_dict(ckpt['network_fn_state_dict']) - if model_fine is not None: - model_fine.load_state_dict(ckpt['network_fine_state_dict']) - - ########################## - - render_kwargs_train = { - 'network_query_fn' : network_query_fn, - 'perturb' : args.perturb, - 'N_importance' : args.N_importance, - 'network_fine' : model_fine, - 'N_samples' : args.N_samples, - 'network_fn' : model, - 'use_viewdirs' : args.use_viewdirs, - 'white_bkgd' : args.white_bkgd, - 'raw_noise_std' : args.raw_noise_std, - } - - # NDC only good for LLFF-style forward facing data - if args.dataset_type != 'llff' or args.no_ndc: - print('Not ndc!') - render_kwargs_train['ndc'] = False - render_kwargs_train['lindisp'] = args.lindisp - - render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train} - render_kwargs_test['perturb'] = False - render_kwargs_test['raw_noise_std'] = 0. - - return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer - - -def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False): - """Transforms model's predictions to semantically meaningful values. - Args: - raw: [num_rays, num_samples along ray, 4]. Prediction from model. - z_vals: [num_rays, num_samples along ray]. Integration time. - rays_d: [num_rays, 3]. Direction of each ray. - Returns: - rgb_map: [num_rays, 3]. Estimated RGB color of a ray. - disp_map: [num_rays]. Disparity map. Inverse of depth map. - acc_map: [num_rays]. Sum of weights along each ray. - weights: [num_rays, num_samples]. Weights assigned to each sampled color. - depth_map: [num_rays]. Estimated distance to object. 
- """ - raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists) - - dists = z_vals[...,1:] - z_vals[...,:-1] - dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[...,:1].shape)], -1) # [N_rays, N_samples] - - dists = dists * torch.norm(rays_d[...,None,:], dim=-1) - - rgb = torch.sigmoid(raw[...,:3]) # [N_rays, N_samples, 3] - noise = 0. - if raw_noise_std > 0.: - noise = torch.randn(raw[...,3].shape) * raw_noise_std - - # Overwrite randomly sampled data if pytest - if pytest: - np.random.seed(0) - noise = np.random.rand(*list(raw[...,3].shape)) * raw_noise_std - noise = torch.Tensor(noise) - - alpha = raw2alpha(raw[...,3] + noise, dists) # [N_rays, N_samples] - # weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True) - weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1] - rgb_map = torch.sum(weights[...,None] * rgb, -2) # [N_rays, 3] - - depth_map = torch.sum(weights * z_vals, -1) - disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1)) - acc_map = torch.sum(weights, -1) - - if white_bkgd: - rgb_map = rgb_map + (1.-acc_map[...,None]) - - return rgb_map, disp_map, acc_map, weights, depth_map - - -def render_rays(ray_batch, - network_fn, - network_query_fn, - N_samples, - retraw=False, - lindisp=False, - perturb=0., - N_importance=0, - network_fine=None, - white_bkgd=False, - raw_noise_std=0., - verbose=False, - pytest=False): - """Volumetric rendering. - Args: - ray_batch: array of shape [batch_size, ...]. All information necessary - for sampling along a ray, including: ray origin, ray direction, min - dist, max dist, and unit-magnitude viewing direction. - network_fn: function. Model for predicting RGB and density at each point - in space. - network_query_fn: function used for passing queries to network_fn. - N_samples: int. Number of different times to sample along each ray. - retraw: bool. If True, include model's raw, unprocessed predictions. - lindisp: bool. If True, sample linearly in inverse depth rather than in depth. - perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified - random points in time. - N_importance: int. Number of additional times to sample along each ray. - These samples are only passed to network_fine. - network_fine: "fine" network with same spec as network_fn. - white_bkgd: bool. If True, assume a white background. - raw_noise_std: ... - verbose: bool. If True, print more debugging info. - Returns: - rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model. - disp_map: [num_rays]. Disparity map. 1 / depth. - acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model. - raw: [num_rays, num_samples, 4]. Raw predictions from model. - rgb0: See rgb_map. Output for coarse model. - disp0: See disp_map. Output for coarse model. - acc0: See acc_map. Output for coarse model. - z_std: [num_rays]. Standard deviation of distances along ray for each - sample. 
- """ - N_rays = ray_batch.shape[0] - rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each - viewdirs = ray_batch[:,-3:] if ray_batch.shape[-1] > 8 else None - bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2]) - near, far = bounds[...,0], bounds[...,1] # [-1,1] - - t_vals = torch.linspace(0., 1., steps=N_samples) - if not lindisp: - z_vals = near * (1.-t_vals) + far * (t_vals) - else: - z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals)) - - z_vals = z_vals.expand([N_rays, N_samples]) - - if perturb > 0.: - # get intervals between samples - mids = .5 * (z_vals[...,1:] + z_vals[...,:-1]) - upper = torch.cat([mids, z_vals[...,-1:]], -1) - lower = torch.cat([z_vals[...,:1], mids], -1) - # stratified samples in those intervals - t_rand = torch.rand(z_vals.shape) - - # Pytest, overwrite u with numpy's fixed random numbers - if pytest: - np.random.seed(0) - t_rand = np.random.rand(*list(z_vals.shape)) - t_rand = torch.Tensor(t_rand) - - z_vals = lower + (upper - lower) * t_rand - - pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples, 3] - - -# raw = run_network(pts) - raw = network_query_fn(pts, viewdirs, network_fn) - rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest) - - if N_importance > 0: - - rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map - - z_vals_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1]) - z_samples = sample_pdf(z_vals_mid, weights[...,1:-1], N_importance, det=(perturb==0.), pytest=pytest) - z_samples = z_samples.detach() - - z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1) - pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples + N_importance, 3] - - run_fn = network_fn if network_fine is None else network_fine -# raw = run_network(pts, fn=run_fn) - raw = network_query_fn(pts, viewdirs, run_fn) - - rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest) - - ret = {'rgb_map' : rgb_map, 'disp_map' : disp_map, 'acc_map' : acc_map} - if retraw: - ret['raw'] = raw - if N_importance > 0: - ret['rgb0'] = rgb_map_0 - ret['disp0'] = disp_map_0 - ret['acc0'] = acc_map_0 - ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False) # [N_rays] - - for k in ret: - if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG: - print(f"! 
[Numerical Error] {k} contains nan or inf.") - - return ret - - -def config_parser(): - - import configargparse - parser = configargparse.ArgumentParser() - parser.add_argument('--config', is_config_file=True, - help='config file path') - parser.add_argument("--expname", type=str, - help='experiment name') - parser.add_argument("--basedir", type=str, default='./logs/', - help='where to store ckpts and logs') - parser.add_argument("--datadir", type=str, default='./data/llff/fern', - help='input data directory') - - # training options - parser.add_argument("--netdepth", type=int, default=8, - help='layers in network') - parser.add_argument("--netwidth", type=int, default=256, - help='channels per layer') - parser.add_argument("--netdepth_fine", type=int, default=8, - help='layers in fine network') - parser.add_argument("--netwidth_fine", type=int, default=256, - help='channels per layer in fine network') - parser.add_argument("--N_rand", type=int, default=32*32*4, - help='batch size (number of random rays per gradient step)') - parser.add_argument("--lrate", type=float, default=5e-4, - help='learning rate') - parser.add_argument("--lrate_decay", type=int, default=250, - help='exponential learning rate decay (in 1000 steps)') - parser.add_argument("--chunk", type=int, default=1024*32, - help='number of rays processed in parallel, decrease if running out of memory') - parser.add_argument("--netchunk", type=int, default=1024*64, - help='number of pts sent through network in parallel, decrease if running out of memory') - parser.add_argument("--no_batching", action='store_true', - help='only take random rays from 1 image at a time') - parser.add_argument("--no_reload", action='store_true', - help='do not reload weights from saved ckpt') - parser.add_argument("--ft_path", type=str, default=None, - help='specific weights npy file to reload for coarse network') - - # rendering options - parser.add_argument("--N_samples", type=int, default=64, - help='number of coarse samples per ray') - parser.add_argument("--N_importance", type=int, default=0, - help='number of additional fine samples per ray') - parser.add_argument("--perturb", type=float, default=1., - help='set to 0. for no jitter, 1. 
for jitter') - parser.add_argument("--use_viewdirs", action='store_true', - help='use full 5D input instead of 3D') - parser.add_argument("--i_embed", type=int, default=0, - help='set 0 for default positional encoding, -1 for none') - parser.add_argument("--multires", type=int, default=10, - help='log2 of max freq for positional encoding (3D location)') - parser.add_argument("--multires_views", type=int, default=4, - help='log2 of max freq for positional encoding (2D direction)') - parser.add_argument("--raw_noise_std", type=float, default=0., - help='std dev of noise added to regularize sigma_a output, 1e0 recommended') - - parser.add_argument("--render_only", action='store_true', - help='do not optimize, reload weights and render out render_poses path') - parser.add_argument("--render_test", action='store_true', - help='render the test set instead of render_poses path') - parser.add_argument("--render_factor", type=int, default=0, - help='downsampling factor to speed up rendering, set 4 or 8 for fast preview') - - # training options - parser.add_argument("--precrop_iters", type=int, default=0, - help='number of steps to train on central crops') - parser.add_argument("--precrop_frac", type=float, - default=.5, help='fraction of img taken for central crops') - - # dataset options - parser.add_argument("--dataset_type", type=str, default='llff', - help='options: llff / blender / deepvoxels') - parser.add_argument("--testskip", type=int, default=8, - help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels') - - ## deepvoxels flags - parser.add_argument("--shape", type=str, default='greek', - help='options : armchair / cube / greek / vase') - - ## blender flags - parser.add_argument("--white_bkgd", action='store_true', - help='set to render synthetic data on a white bkgd (always use for dvoxels)') - parser.add_argument("--half_res", action='store_true', - help='load blender synthetic data at 400x400 instead of 800x800') - - ## llff flags - parser.add_argument("--factor", type=int, default=8, - help='downsample factor for LLFF images') - parser.add_argument("--no_ndc", action='store_true', - help='do not use normalized device coordinates (set for non-forward facing scenes)') - parser.add_argument("--lindisp", action='store_true', - help='sampling linearly in disparity rather than depth') - parser.add_argument("--spherify", action='store_true', - help='set for spherical 360 scenes') - parser.add_argument("--llffhold", type=int, default=8, - help='will take every 1/N images as LLFF test set, paper uses 8') - - # logging/saving options - parser.add_argument("--i_print", type=int, default=100, - help='frequency of console printout and metric loggin') - parser.add_argument("--i_img", type=int, default=500, - help='frequency of tensorboard image logging') - parser.add_argument("--i_weights", type=int, default=10000, - help='frequency of weight ckpt saving') - parser.add_argument("--i_testset", type=int, default=50000, - help='frequency of testset saving') - parser.add_argument("--i_video", type=int, default=50000, - help='frequency of render_poses video saving') - - return parser - - -def train(): - - parser = config_parser() - args = parser.parse_args() - - # Load data - K = None - if args.dataset_type == 'llff': - images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor, - recenter=True, bd_factor=.75, - spherify=args.spherify) - hwf = poses[0,:3,-1] - poses = poses[:,:3,:4] - print('Loaded llff', images.shape, render_poses.shape, hwf, 
args.datadir) - if not isinstance(i_test, list): - i_test = [i_test] - - if args.llffhold > 0: - print('Auto LLFF holdout,', args.llffhold) - i_test = np.arange(images.shape[0])[::args.llffhold] - - i_val = i_test - i_train = np.array([i for i in np.arange(int(images.shape[0])) if - (i not in i_test and i not in i_val)]) - - print('DEFINING BOUNDS') - if args.no_ndc: - near = np.ndarray.min(bds) * .9 - far = np.ndarray.max(bds) * 1. - - else: - near = 0. - far = 1. - print('NEAR FAR', near, far) - - elif args.dataset_type == 'blender': - images, poses, render_poses, hwf, i_split = load_blender_data(args.datadir, args.half_res, args.testskip) - print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir) - i_train, i_val, i_test = i_split - - near = 2. - far = 6. - - if args.white_bkgd: - images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:]) - else: - images = images[...,:3] - - elif args.dataset_type == 'LINEMOD': - images, poses, render_poses, hwf, K, i_split, near, far = load_LINEMOD_data(args.datadir, args.half_res, args.testskip) - print(f'Loaded LINEMOD, images shape: {images.shape}, hwf: {hwf}, K: {K}') - print(f'[CHECK HERE] near: {near}, far: {far}.') - i_train, i_val, i_test = i_split - - if args.white_bkgd: - images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:]) - else: - images = images[...,:3] - - elif args.dataset_type == 'deepvoxels': - - images, poses, render_poses, hwf, i_split = load_dv_data(scene=args.shape, - basedir=args.datadir, - testskip=args.testskip) - - print('Loaded deepvoxels', images.shape, render_poses.shape, hwf, args.datadir) - i_train, i_val, i_test = i_split - - hemi_R = np.mean(np.linalg.norm(poses[:,:3,-1], axis=-1)) - near = hemi_R-1. - far = hemi_R+1. - - else: - print('Unknown dataset type', args.dataset_type, 'exiting') - return - - # Cast intrinsics to right types - H, W, focal = hwf - H, W = int(H), int(W) - hwf = [H, W, focal] - - if K is None: - K = np.array([ - [focal, 0, 0.5*W], - [0, focal, 0.5*H], - [0, 0, 1] - ]) - - if args.render_test: - render_poses = np.array(poses[i_test]) - - # Create log dir and copy the config file - basedir = args.basedir - expname = args.expname - os.makedirs(os.path.join(basedir, expname), exist_ok=True) - f = os.path.join(basedir, expname, 'args.txt') - with open(f, 'w') as file: - for arg in sorted(vars(args)): - attr = getattr(args, arg) - file.write('{} = {}\n'.format(arg, attr)) - if args.config is not None: - f = os.path.join(basedir, expname, 'config.txt') - with open(f, 'w') as file: - file.write(open(args.config, 'r').read()) - - # Create nerf model - render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args) - global_step = start - - bds_dict = { - 'near' : near, - 'far' : far, - } - render_kwargs_train.update(bds_dict) - render_kwargs_test.update(bds_dict) - - # Move testing data to GPU - render_poses = torch.Tensor(render_poses).to(device) - - # Short circuit if only rendering out from trained model - if args.render_only: - print('RENDER ONLY') - with torch.no_grad(): - if args.render_test: - # render_test switches to test poses - images = images[i_test] - else: - # Default is smoother render_poses path - images = None - - testsavedir = os.path.join(basedir, expname, 'renderonly_{}_{:06d}'.format('test' if args.render_test else 'path', start)) - os.makedirs(testsavedir, exist_ok=True) - print('test poses shape', render_poses.shape) - - rgbs, _ = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test, gt_imgs=images, 
savedir=testsavedir, render_factor=args.render_factor) - print('Done rendering', testsavedir) - imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8) - - return - - # Prepare raybatch tensor if batching random rays - N_rand = args.N_rand - use_batching = not args.no_batching - if use_batching: - # For random ray batching - print('get rays') - rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:,:3,:4]], 0) # [N, ro+rd, H, W, 3] - print('done, concats') - rays_rgb = np.concatenate([rays, images[:,None]], 1) # [N, ro+rd+rgb, H, W, 3] - rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4]) # [N, H, W, ro+rd+rgb, 3] - rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0) # train images only - rays_rgb = np.reshape(rays_rgb, [-1,3,3]) # [(N-1)*H*W, ro+rd+rgb, 3] - rays_rgb = rays_rgb.astype(np.float32) - print('shuffle rays') - np.random.shuffle(rays_rgb) - - print('done') - i_batch = 0 - - # Move training data to GPU - if use_batching: - images = torch.Tensor(images).to(device) - poses = torch.Tensor(poses).to(device) - if use_batching: - rays_rgb = torch.Tensor(rays_rgb).to(device) - - - N_iters = 200000 + 1 - print('Begin') - print('TRAIN views are', i_train) - print('TEST views are', i_test) - print('VAL views are', i_val) - - # Summary writers - # writer = SummaryWriter(os.path.join(basedir, 'summaries', expname)) - - start = start + 1 - for i in trange(start, N_iters): - time0 = time.time() - - # Sample random ray batch - if use_batching: - # Random over all images - batch = rays_rgb[i_batch:i_batch+N_rand] # [B, 2+1, 3*?] - batch = torch.transpose(batch, 0, 1) - batch_rays, target_s = batch[:2], batch[2] - - i_batch += N_rand - if i_batch >= rays_rgb.shape[0]: - print("Shuffle data after an epoch!") - rand_idx = torch.randperm(rays_rgb.shape[0]) - rays_rgb = rays_rgb[rand_idx] - i_batch = 0 - - else: - # Random from one image - img_i = np.random.choice(i_train) - target = images[img_i] - target = torch.Tensor(target).to(device) - pose = poses[img_i, :3,:4] - - if N_rand is not None: - rays_o, rays_d = get_rays(H, W, K, torch.Tensor(pose)) # (H, W, 3), (H, W, 3) - - if i < args.precrop_iters: - dH = int(H//2 * args.precrop_frac) - dW = int(W//2 * args.precrop_frac) - coords = torch.stack( - torch.meshgrid( - torch.linspace(H//2 - dH, H//2 + dH - 1, 2*dH), - torch.linspace(W//2 - dW, W//2 + dW - 1, 2*dW) - ), -1) - if i == start: - print(f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}") - else: - coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1) # (H, W, 2) - - coords = torch.reshape(coords, [-1,2]) # (H * W, 2) - select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,) - select_coords = coords[select_inds].long() # (N_rand, 2) - rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) - rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) - batch_rays = torch.stack([rays_o, rays_d], 0) - target_s = target[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) - - ##### Core optimization loop ##### - rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays, - verbose=i < 10, retraw=True, - **render_kwargs_train) - - optimizer.zero_grad() - img_loss = img2mse(rgb, target_s) - trans = extras['raw'][...,-1] - loss = img_loss - psnr = mse2psnr(img_loss) - - if 'rgb0' in extras: - img_loss0 = img2mse(extras['rgb0'], target_s) - loss = loss + img_loss0 - psnr0 = mse2psnr(img_loss0) - - 
loss.backward() - optimizer.step() - - # NOTE: IMPORTANT! - ### update learning rate ### - decay_rate = 0.1 - decay_steps = args.lrate_decay * 1000 - new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps)) - for param_group in optimizer.param_groups: - param_group['lr'] = new_lrate - ################################ - - dt = time.time()-time0 - # print(f"Step: {global_step}, Loss: {loss}, Time: {dt}") - ##### end ##### - - # Rest is logging - if i%args.i_weights==0: - path = os.path.join(basedir, expname, '{:06d}.tar'.format(i)) - torch.save({ - 'global_step': global_step, - 'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(), - 'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(), - 'optimizer_state_dict': optimizer.state_dict(), - }, path) - print('Saved checkpoints at', path) - - if i%args.i_video==0 and i > 0: - # Turn on testing mode - with torch.no_grad(): - rgbs, disps = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test) - print('Done, saving', rgbs.shape, disps.shape) - moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i)) - imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8) - imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8) - - # if args.use_viewdirs: - # render_kwargs_test['c2w_staticcam'] = render_poses[0][:3,:4] - # with torch.no_grad(): - # rgbs_still, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test) - # render_kwargs_test['c2w_staticcam'] = None - # imageio.mimwrite(moviebase + 'rgb_still.mp4', to8b(rgbs_still), fps=30, quality=8) - - if i%args.i_testset==0 and i > 0: - testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i)) - os.makedirs(testsavedir, exist_ok=True) - print('test poses shape', poses[i_test].shape) - with torch.no_grad(): - render_path(torch.Tensor(poses[i_test]).to(device), hwf, K, args.chunk, render_kwargs_test, gt_imgs=images[i_test], savedir=testsavedir) - print('Saved test set') - - - - if i%args.i_print==0: - tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}") - """ - print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy()) - print('iter time {:.05f}'.format(dt)) - - with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_print): - tf.contrib.summary.scalar('loss', loss) - tf.contrib.summary.scalar('psnr', psnr) - tf.contrib.summary.histogram('tran', trans) - if args.N_importance > 0: - tf.contrib.summary.scalar('psnr0', psnr0) - - - if i%args.i_img==0: - - # Log a rendered validation view to Tensorboard - img_i=np.random.choice(i_val) - target = images[img_i] - pose = poses[img_i, :3,:4] - with torch.no_grad(): - rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, c2w=pose, - **render_kwargs_test) - - psnr = mse2psnr(img2mse(rgb, target)) - - with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img): - - tf.contrib.summary.image('rgb', to8b(rgb)[tf.newaxis]) - tf.contrib.summary.image('disp', disp[tf.newaxis,...,tf.newaxis]) - tf.contrib.summary.image('acc', acc[tf.newaxis,...,tf.newaxis]) - - tf.contrib.summary.scalar('psnr_holdout', psnr) - tf.contrib.summary.image('rgb_holdout', target[tf.newaxis]) - - - if args.N_importance > 0: - - with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img): - tf.contrib.summary.image('rgb0', to8b(extras['rgb0'])[tf.newaxis]) - tf.contrib.summary.image('disp0', extras['disp0'][tf.newaxis,...,tf.newaxis]) - tf.contrib.summary.image('z_std', 
extras['z_std'][tf.newaxis,...,tf.newaxis]) - """ - - global_step += 1 - - -if __name__=='__main__': - torch.set_default_tensor_type('torch.cuda.FloatTensor') - - train() diff --git a/run_nerf_helpers.py b/run_nerf_helpers.py deleted file mode 100644 index bc6ee779d..000000000 --- a/run_nerf_helpers.py +++ /dev/null @@ -1,239 +0,0 @@ -import torch -# torch.autograd.set_detect_anomaly(True) -import torch.nn as nn -import torch.nn.functional as F -import numpy as np - - -# Misc -img2mse = lambda x, y : torch.mean((x - y) ** 2) -mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.])) -to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8) - - -# Positional encoding (section 5.1) -class Embedder: - def __init__(self, **kwargs): - self.kwargs = kwargs - self.create_embedding_fn() - - def create_embedding_fn(self): - embed_fns = [] - d = self.kwargs['input_dims'] - out_dim = 0 - if self.kwargs['include_input']: - embed_fns.append(lambda x : x) - out_dim += d - - max_freq = self.kwargs['max_freq_log2'] - N_freqs = self.kwargs['num_freqs'] - - if self.kwargs['log_sampling']: - freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs) - else: - freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs) - - for freq in freq_bands: - for p_fn in self.kwargs['periodic_fns']: - embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq)) - out_dim += d - - self.embed_fns = embed_fns - self.out_dim = out_dim - - def embed(self, inputs): - return torch.cat([fn(inputs) for fn in self.embed_fns], -1) - - -def get_embedder(multires, i=0): - if i == -1: - return nn.Identity(), 3 - - embed_kwargs = { - 'include_input' : True, - 'input_dims' : 3, - 'max_freq_log2' : multires-1, - 'num_freqs' : multires, - 'log_sampling' : True, - 'periodic_fns' : [torch.sin, torch.cos], - } - - embedder_obj = Embedder(**embed_kwargs) - embed = lambda x, eo=embedder_obj : eo.embed(x) - return embed, embedder_obj.out_dim - - -# Model -class NeRF(nn.Module): - def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False): - """ - """ - super(NeRF, self).__init__() - self.D = D - self.W = W - self.input_ch = input_ch - self.input_ch_views = input_ch_views - self.skips = skips - self.use_viewdirs = use_viewdirs - - self.pts_linears = nn.ModuleList( - [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)]) - - ### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105) - self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)]) - - ### Implementation according to the paper - # self.views_linears = nn.ModuleList( - # [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)]) - - if use_viewdirs: - self.feature_linear = nn.Linear(W, W) - self.alpha_linear = nn.Linear(W, 1) - self.rgb_linear = nn.Linear(W//2, 3) - else: - self.output_linear = nn.Linear(W, output_ch) - - def forward(self, x): - input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1) - h = input_pts - for i, l in enumerate(self.pts_linears): - h = self.pts_linears[i](h) - h = F.relu(h) - if i in self.skips: - h = torch.cat([input_pts, h], -1) - - if self.use_viewdirs: - alpha = self.alpha_linear(h) - feature = self.feature_linear(h) - h = torch.cat([feature, input_views], -1) - - for i, l in enumerate(self.views_linears): - h = self.views_linears[i](h) - h = F.relu(h) - - rgb = 
self.rgb_linear(h) - outputs = torch.cat([rgb, alpha], -1) - else: - outputs = self.output_linear(h) - - return outputs - - def load_weights_from_keras(self, weights): - assert self.use_viewdirs, "Not implemented if use_viewdirs=False" - - # Load pts_linears - for i in range(self.D): - idx_pts_linears = 2 * i - self.pts_linears[i].weight.data = torch.from_numpy(np.transpose(weights[idx_pts_linears])) - self.pts_linears[i].bias.data = torch.from_numpy(np.transpose(weights[idx_pts_linears+1])) - - # Load feature_linear - idx_feature_linear = 2 * self.D - self.feature_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_feature_linear])) - self.feature_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_feature_linear+1])) - - # Load views_linears - idx_views_linears = 2 * self.D + 2 - self.views_linears[0].weight.data = torch.from_numpy(np.transpose(weights[idx_views_linears])) - self.views_linears[0].bias.data = torch.from_numpy(np.transpose(weights[idx_views_linears+1])) - - # Load rgb_linear - idx_rbg_linear = 2 * self.D + 4 - self.rgb_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear])) - self.rgb_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear+1])) - - # Load alpha_linear - idx_alpha_linear = 2 * self.D + 6 - self.alpha_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear])) - self.alpha_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear+1])) - - - -# Ray helpers -def get_rays(H, W, K, c2w): - i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H)) # pytorch's meshgrid has indexing='ij' - i = i.t() - j = j.t() - dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1) - # Rotate ray directions from camera frame to the world frame - rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] - # Translate camera frame's origin to the world frame. It is the origin of all rays. - rays_o = c2w[:3,-1].expand(rays_d.shape) - return rays_o, rays_d - - -def get_rays_np(H, W, K, c2w): - i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy') - dirs = np.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -np.ones_like(i)], -1) - # Rotate ray directions from camera frame to the world frame - rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] - # Translate camera frame's origin to the world frame. It is the origin of all rays. - rays_o = np.broadcast_to(c2w[:3,-1], np.shape(rays_d)) - return rays_o, rays_d - - -def ndc_rays(H, W, focal, near, rays_o, rays_d): - # Shift ray origins to near plane - t = -(near + rays_o[...,2]) / rays_d[...,2] - rays_o = rays_o + t[...,None] * rays_d - - # Projection - o0 = -1./(W/(2.*focal)) * rays_o[...,0] / rays_o[...,2] - o1 = -1./(H/(2.*focal)) * rays_o[...,1] / rays_o[...,2] - o2 = 1. + 2. * near / rays_o[...,2] - - d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - rays_o[...,0]/rays_o[...,2]) - d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - rays_o[...,1]/rays_o[...,2]) - d2 = -2. 
* near / rays_o[...,2]
-
-    rays_o = torch.stack([o0,o1,o2], -1)
-    rays_d = torch.stack([d0,d1,d2], -1)
-
-    return rays_o, rays_d
-
-
-# Hierarchical sampling (section 5.2)
-def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
-    # Get pdf
-    weights = weights + 1e-5 # prevent nans
-    pdf = weights / torch.sum(weights, -1, keepdim=True)
-    cdf = torch.cumsum(pdf, -1)
-    cdf = torch.cat([torch.zeros_like(cdf[...,:1]), cdf], -1)  # (batch, len(bins))
-
-    # Take uniform samples
-    if det:
-        u = torch.linspace(0., 1., steps=N_samples)
-        u = u.expand(list(cdf.shape[:-1]) + [N_samples])
-    else:
-        u = torch.rand(list(cdf.shape[:-1]) + [N_samples])
-
-    # Pytest, overwrite u with numpy's fixed random numbers
-    if pytest:
-        np.random.seed(0)
-        new_shape = list(cdf.shape[:-1]) + [N_samples]
-        if det:
-            u = np.linspace(0., 1., N_samples)
-            u = np.broadcast_to(u, new_shape)
-        else:
-            u = np.random.rand(*new_shape)
-        u = torch.Tensor(u)
-
-    # Invert CDF
-    u = u.contiguous()
-    inds = torch.searchsorted(cdf, u, right=True)
-    below = torch.max(torch.zeros_like(inds-1), inds-1)
-    above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds)
-    inds_g = torch.stack([below, above], -1)  # (batch, N_samples, 2)
-
-    # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
-    # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
-    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
-    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
-    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
-
-    denom = (cdf_g[...,1]-cdf_g[...,0])
-    denom = torch.where(denom<1e-5, torch.ones_like(denom), denom)
-    t = (u-cdf_g[...,0])/denom
-    samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0])
-
-    return samples
diff --git a/src/dataset.cpp b/src/dataset.cpp
new file mode 100644
index 000000000..6ffe35180
--- /dev/null
+++ b/src/dataset.cpp
@@ -0,0 +1,152 @@
+#include "nerf/dataset.hpp"
+#include <opencv2/opencv.hpp>
+#include <filesystem>
+#include <fstream>
+#include <stdexcept>
+
+namespace nerf {
+
+Dataset::Dataset(const std::string& datadir,
+                 const std::string& dataset_type,
+                 int factor,
+                 bool use_viewdirs,
+                 bool white_bkgd)
+    : datadir_(datadir),
+      dataset_type_(dataset_type),
+      factor_(factor),
+      use_viewdirs_(use_viewdirs),
+      white_bkgd_(white_bkgd) {
+
+    if (dataset_type_ == "llff") {
+        load_llff_data();
+    } else if (dataset_type_ == "blender") {
+        load_blender_data();
+    } else {
+        throw std::runtime_error("Unknown dataset type: " + dataset_type_);
+    }
+}
+
+void Dataset::load_llff_data() {
+    // Load images
+    auto images_dir = std::filesystem::path(datadir_) / "images";
+    std::vector<torch::Tensor> images;
+    for (const auto& entry : std::filesystem::directory_iterator(images_dir)) {
+        if (entry.path().extension() == ".png" || entry.path().extension() == ".jpg") {
+            images.push_back(load_image(entry.path().string()));
+        }
+    }
+
+    // Stack images
+    auto images_tensor = torch::stack(images);
+
+    // Load poses
+    auto poses = load_poses((std::filesystem::path(datadir_) / "poses_bounds.npy").string());
+
+    // Get image dimensions
+    H_ = images_tensor.size(1);
+    W_ = images_tensor.size(2);
+
+    // Load camera parameters
+    auto K = load_poses((std::filesystem::path(datadir_) / "hwf.npy").string());
+    focal_ = K[0].item<float>();
+    K_ = torch::tensor({
+        {focal_, 0.0f, W_ / 2.0f},
+        {0.0f, focal_, H_ / 2.0f},
+        {0.0f, 0.0f, 1.0f}
+    });
+
+    // Set near and far planes
+    near_ = torch::tensor(0.0f);
+    far_ = torch::tensor(1.0f);
+}
+
+void Dataset::load_blender_data() {
+    // Load images
+    auto images_dir = std::filesystem::path(datadir_) / "train";
+    std::vector<torch::Tensor> images;
+    for (const auto& entry : std::filesystem::directory_iterator(images_dir)) {
+        if (entry.path().extension() == ".png" || entry.path().extension() == ".jpg") {
+            images.push_back(load_image(entry.path().string()));
+        }
+    }
+
+    // Stack images
+    auto images_tensor = torch::stack(images);
+
+    // Load poses
+    auto poses = load_poses((std::filesystem::path(datadir_) / "transforms_train.json").string());
+
+    // Get image dimensions
+    H_ = images_tensor.size(1);
+    W_ = images_tensor.size(2);
+
+    // Set camera parameters
+    focal_ = 138.88887889922103f; // Default focal length for Blender dataset
+    K_ = torch::tensor({
+        {focal_, 0.0f, W_ / 2.0f},
+        {0.0f, focal_, H_ / 2.0f},
+        {0.0f, 0.0f, 1.0f}
+    });
+
+    // Set near and far planes
+    near_ = torch::tensor(2.0f);
+    far_ = torch::tensor(6.0f);
+}
+
+torch::Tensor Dataset::load_image(const std::string& path) {
+    cv::Mat img = cv::imread(path);
+    cv::cvtColor(img, img, cv::COLOR_BGR2RGB);
+
+    // Resize if needed
+    if (factor_ > 1) {
+        cv::resize(img, img, cv::Size(), 1.0f / factor_, 1.0f / factor_);
+    }
+
+    // Convert to float and normalize
+    cv::Mat float_img;
+    img.convertTo(float_img, CV_32F, 1.0f / 255.0f);
+
+    // Convert to torch tensor
+    auto tensor = torch::from_blob(float_img.data, {img.rows, img.cols, 3}, torch::kFloat32);
+    return tensor.clone(); // Clone to ensure memory ownership
+}
+
+torch::Tensor Dataset::load_poses(const std::string& path) {
+    // This is a simplified version. In practice, you'll need to implement
+    // proper loading of different pose file formats (npy, json, etc.)
+    std::ifstream file(path, std::ios::binary);
+    if (!file) {
+        throw std::runtime_error("Could not open pose file: " + path);
+    }
+
+    // Read file header to determine format
+    std::string header;
+    std::getline(file, header);
+
+    if (path.find(".npy") != std::string::npos) {
+        // Load numpy array
+        // Implementation depends on your numpy file reading library
+        throw std::runtime_error("Numpy file loading not implemented");
+    } else if (path.find(".json") != std::string::npos) {
+        // Load JSON
+        // Implementation depends on your JSON parsing library
+        throw std::runtime_error("JSON file loading not implemented");
+    } else {
+        throw std::runtime_error("Unsupported pose file format: " + path);
+    }
+}
+
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+Dataset::get_data() {
+    // This is a placeholder. In practice, you'll need to implement
+    // proper data loading and preprocessing based on your dataset type
+    return std::make_tuple(
+        torch::zeros({1}),  // images
+        torch::zeros({1}),  // poses
+        torch::zeros({1}),  // render_poses
+        torch::zeros({1}),  // hwf
+        torch::zeros({1})   // i_split
+    );
+}
+
+} // namespace nerf
\ No newline at end of file
diff --git a/src/model.cpp b/src/model.cpp
new file mode 100644
index 000000000..3fbf39477
--- /dev/null
+++ b/src/model.cpp
@@ -0,0 +1,111 @@
+#include "nerf/model.hpp"
+#include <cmath>
+
+namespace nerf {
+
+NeRFModel::NeRFModel(int netdepth, int netwidth, int netdepth_fine, int netwidth_fine,
+                     int multires, int multires_views, bool use_viewdirs)
+    : netdepth_(netdepth), netwidth_(netwidth),
+      netdepth_fine_(netdepth_fine), netwidth_fine_(netwidth_fine),
+      multires_(multires), multires_views_(multires_views),
+      use_viewdirs_(use_viewdirs) {
+
+    // Create the main network. Layers are pushed straight into the Sequential,
+    // with a ReLU between hidden layers as in the reference PyTorch model.
+    net_ = torch::nn::Sequential();
+    int input_ch = 3 + 2 * multires_ * 3;  // 3D position + positional encoding
+
+    for (int i = 0; i < netdepth_; i++) {
+        int in_channels = (i == 0) ? input_ch : netwidth_;
+        net_->push_back(torch::nn::Linear(in_channels, netwidth_));
+        if (i + 1 < netdepth_) {
+            net_->push_back(torch::nn::ReLU());
+        }
+    }
+    register_module("net", net_);
+
+    // Create the fine network
+    net_fine_ = torch::nn::Sequential();
+    for (int i = 0; i < netdepth_fine_; i++) {
+        int in_channels = (i == 0) ? input_ch : netwidth_fine_;
+        net_fine_->push_back(torch::nn::Linear(in_channels, netwidth_fine_));
+        if (i + 1 < netdepth_fine_) {
+            net_fine_->push_back(torch::nn::ReLU());
+        }
+    }
+    register_module("net_fine", net_fine_);
+
+    // Create view-dependent networks if needed
+    if (use_viewdirs_) {
+        int input_ch_views = 3 + 2 * multires_views_ * 3; // View direction + positional encoding
+        viewdirs_net_ = torch::nn::Linear(netwidth_ + input_ch_views, netwidth_ / 2);
+        viewdirs_net_fine_ = torch::nn::Linear(netwidth_fine_ + input_ch_views, netwidth_fine_ / 2);
+        register_module("viewdirs_net", viewdirs_net_);
+        register_module("viewdirs_net_fine", viewdirs_net_fine_);
+    }
+}
+
+torch::Tensor NeRFModel::embed_fn(const torch::Tensor& inputs) {
+    // Frequency encoding applied elementwise to all coordinates, matching the
+    // Python Embedder: [x, sin(2^0 x), cos(2^0 x), ..., sin(2^(L-1) x), cos(2^(L-1) x)].
+    auto x = inputs;
+    std::vector<torch::Tensor> embeds;
+    embeds.push_back(x);
+
+    for (int i = 0; i < multires_; i++) {
+        float freq = std::pow(2.0f, static_cast<float>(i));
+        embeds.push_back(torch::sin(freq * x));
+        embeds.push_back(torch::cos(freq * x));
+    }
+
+    return torch::cat(embeds, -1);
+}
+
+torch::Tensor NeRFModel::embeddirs_fn(const torch::Tensor& inputs) {
+    auto x = inputs;
+    std::vector<torch::Tensor> embeds;
+    embeds.push_back(x);
+
+    for (int i = 0; i < multires_views_; i++) {
+        float freq = std::pow(2.0f, static_cast<float>(i));
+        embeds.push_back(torch::sin(freq * x));
+        embeds.push_back(torch::cos(freq * x));
+    }
+
+    return torch::cat(embeds, -1);
+}
+
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> NeRFModel::forward(
+    const torch::Tensor& inputs_flat,
+    const torch::Tensor& viewdirs,
+    bool is_fine) {
+
+    auto inputs_embedded = embed_fn(inputs_flat);
+    auto net = is_fine ? net_fine_ : net_;
+    auto x = net->forward(inputs_embedded);
+
+    if (use_viewdirs_ && viewdirs.numel() > 0) {
+        auto viewdirs_embedded = embeddirs_fn(viewdirs);
+        auto viewdirs_net = is_fine ? viewdirs_net_fine_ : viewdirs_net_;
+        x = torch::cat({x, viewdirs_embedded}, -1);
+        x = viewdirs_net->forward(x);
+    }
+
+    auto rgb = torch::sigmoid(x.index({torch::indexing::Slice(), torch::indexing::Slice(0, 3)}));
+    auto alpha = torch::sigmoid(x.index({torch::indexing::Slice(), 3}));
+    auto raw = x;
+
+    return std::make_tuple(rgb, alpha, raw);
+}
+
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> NeRFModel::get_outputs(
+    const torch::Tensor& inputs_flat,
+    const torch::Tensor& viewdirs,
+    bool is_fine) {
+    return forward(inputs_flat, viewdirs, is_fine);
+}
+
+} // namespace nerf
\ No newline at end of file
diff --git a/src/renderer.cpp b/src/renderer.cpp
new file mode 100644
index 000000000..fee4ca3a0
--- /dev/null
+++ b/src/renderer.cpp
@@ -0,0 +1,155 @@
+#include "nerf/renderer.hpp"
+#include <tuple>
+
+namespace nerf {
+
+Renderer::Renderer(std::shared_ptr<NeRFModel> model,
+                   int N_samples,
+                   int N_importance,
+                   bool use_viewdirs,
+                   float raw_noise_std,
+                   bool white_bkgd)
+    : model_(model),
+      N_samples_(N_samples),
+      N_importance_(N_importance),
+      use_viewdirs_(use_viewdirs),
+      raw_noise_std_(raw_noise_std),
+      white_bkgd_(white_bkgd) {}
+
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> Renderer::render_rays(
+    const torch::Tensor& rays_o,
+    const torch::Tensor& rays_d,
+    const torch::Tensor& viewdirs,
+    const torch::Tensor& near,
+    const torch::Tensor& far,
+    bool is_fine) {
+
+    // Get number of rays
+    int64_t N_rays = rays_o.size(0);
+
+    // Sample points along rays
+    auto t_vals = torch::linspace(0, 1, N_samples_);
+    auto z_vals = near.expand({N_rays, N_samples_}) * (1 - t_vals) +
+                  far.expand({N_rays, N_samples_}) * t_vals;
+
+    // Add noise to z_vals
+    if (raw_noise_std_ > 0) {
+        auto noise = torch::randn({N_rays, N_samples_}) * raw_noise_std_;
+        z_vals = z_vals + noise;
+    }
+
+    // Get points along rays
+    auto pts = rays_o.unsqueeze(1) + rays_d.unsqueeze(1) * z_vals.unsqueeze(-1);
+
+    // Flatten points and viewdirs
+    auto pts_flat = pts.reshape({-1, 3});
+    auto viewdirs_flat = viewdirs.unsqueeze(1).expand({N_rays, N_samples_, 3})
+                                 .reshape({-1, 3});
+
+    // Get model outputs
+    auto [rgb, alpha, raw] = model_->forward(pts_flat, viewdirs_flat, is_fine);
+
+    // Reshape outputs
+    rgb = rgb.reshape({N_rays, N_samples_, 3});
+    alpha = alpha.reshape({N_rays, N_samples_});
+
+    // Compute weights
+    auto weights = compute_accumulated_transmittance(alpha);
+
+    // Compute final RGB
+    auto rgb_map = torch::sum(weights.unsqueeze(-1) * rgb, 1);
+
+    // Compute depth map
+    auto depth_map = torch::sum(weights * z_vals, 1);
+
+    // Compute accumulated transmittance
+    auto acc_map = torch::sum(weights, 1);
+
+    // Add white background if needed
+    if (white_bkgd_) {
+        rgb_map = rgb_map + (1 - acc_map).unsqueeze(-1);
+    }
+
+    return std::make_tuple(rgb_map, depth_map, acc_map, weights);
+}
+
+torch::Tensor Renderer::render(
+    const torch::Tensor& H,
+    const torch::Tensor& W,
+    const torch::Tensor& K,
+    const torch::Tensor& c2w,
+    const torch::Tensor& near,
+    const torch::Tensor& far,
+    bool is_fine) {
+
+    // Create pixel coordinates (float, so they can be mixed with the intrinsics)
+    auto i = torch::arange(H.item<float>());
+    auto j = torch::arange(W.item<float>());
+    auto ij = torch::meshgrid({i, j});
+    auto dirs = torch::stack({
+        (ij[1] - K[0][2].item<float>()) / K[0][0].item<float>(),
+        -(ij[0] - K[1][2].item<float>()) / K[1][1].item<float>(),
+        -torch::ones_like(ij[0])
+    }, -1);
+
+    // Rotate ray directions from camera frame to world frame
+    auto rays_d = torch::sum(dirs.unsqueeze(-2) * c2w.index({torch::indexing::Slice(0, 3), torch::indexing::Slice(0, 3)}), -1);
+
+    // Normalize ray directions
+    rays_d = rays_d / torch::norm(rays_d, 2, -1, true);
+
+    // Get ray origins
+    auto rays_o = c2w.index({torch::indexing::Slice(0, 3), 3}).expand(rays_d.sizes());
+
+    // Reshape for rendering
+    rays_o = rays_o.reshape({-1, 3});
+    rays_d = rays_d.reshape({-1, 3});
+
+    // Render rays
+    auto [rgb_map, depth_map, acc_map, _] = render_rays(rays_o, rays_d, rays_d, near, far, is_fine);
+
+    // Reshape outputs
+    rgb_map = rgb_map.reshape({H.item<int64_t>(), W.item<int64_t>(), 3});
+    depth_map = depth_map.reshape({H.item<int64_t>(), W.item<int64_t>()});
+    acc_map = acc_map.reshape({H.item<int64_t>(), W.item<int64_t>()});
+
+    return rgb_map;
+}
+
+torch::Tensor Renderer::sample_pdf(const torch::Tensor& bins, const torch::Tensor& weights, int N_samples) {
+    // Normalize weights (use a local copy; `weights` is a const reference)
+    auto w = weights + 1e-5;  // prevent NaNs
+    auto pdf = w / torch::sum(w, -1, true);
+    auto cdf = torch::cumsum(pdf, -1);
+    cdf = torch::cat({torch::zeros_like(cdf.index({torch::indexing::Slice(), torch::indexing::Slice(0, 1)})), cdf}, -1);
+
+    // Take uniform samples
+    auto u = torch::rand({w.size(0), N_samples});
+
+    // Invert CDF
+    auto inds = torch::searchsorted(cdf, u, /*out_int32=*/false, /*right=*/true);
+    auto below = torch::max(torch::zeros_like(inds), inds - 1);
+    auto above = torch::min(torch::ones_like(inds) * (cdf.size(-1) - 1), inds);
+    auto inds_g = torch::stack({below, above}, -1);
+
+    // gather needs the source to have the same rank as the index, so expand
+    // cdf/bins to [N_rays, N_samples, len(bins)] first (as in the Python version).
+    auto cdf_exp = cdf.unsqueeze(1).expand({inds_g.size(0), inds_g.size(1), cdf.size(-1)});
+    auto bins_exp = bins.unsqueeze(1).expand({inds_g.size(0), inds_g.size(1), bins.size(-1)});
+    auto cdf_g = torch::gather(cdf_exp, 2, inds_g);
+    auto bins_g = torch::gather(bins_exp, 2, inds_g);
+
+    auto denom = (cdf_g.index({torch::indexing::Slice(), torch::indexing::Slice(), 1}) -
+                  cdf_g.index({torch::indexing::Slice(), torch::indexing::Slice(), 0}));
+    denom = torch::where(denom < 1e-5, torch::ones_like(denom), denom);
+    auto t = (u - cdf_g.index({torch::indexing::Slice(), torch::indexing::Slice(), 0})) / denom;
+    auto samples = bins_g.index({torch::indexing::Slice(), torch::indexing::Slice(), 0}) +
+                   t * (bins_g.index({torch::indexing::Slice(), torch::indexing::Slice(), 1}) -
+                        bins_g.index({torch::indexing::Slice(), torch::indexing::Slice(), 0}));
+
+    return samples;
+}
+
+torch::Tensor Renderer::compute_accumulated_transmittance(const torch::Tensor& alphas) {
+    auto transmittance = torch::cumprod(1 - alphas + 1e-10, -1);
+    return alphas * torch::cat({torch::ones_like(transmittance.index({torch::indexing::Slice(), torch::indexing::Slice(0, 1)})),
+                                transmittance.index({torch::indexing::Slice(), torch::indexing::Slice(0, -1)})}, -1);
+}
+
+} // namespace nerf
\ No newline at end of file
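
The first `Linear` layer in `src/model.cpp` is sized with `input_ch = 3 + 2 * 3 * multires`, which only matches `embed_fn` if the sin/cos encoding is applied elementwise to all three coordinates at each frequency. With the defaults used by the removed Python code (`multires = 10`, `multires_views = 4`) that gives 63 position features and 27 direction features. The snippet below is a minimal, self-contained sketch of that encoding (the free function `positional_encode` is illustrative only, not part of the sources above):

```cpp
// Minimal sketch of the frequency encoding used by NeRFModel::embed_fn.
// The 2^i frequencies follow the "log_sampling" path of the original Python Embedder.
#include <torch/torch.h>
#include <cmath>
#include <iostream>
#include <vector>

torch::Tensor positional_encode(const torch::Tensor& x, int multires) {
    std::vector<torch::Tensor> feats{x};  // include the raw input
    for (int i = 0; i < multires; ++i) {
        float freq = std::pow(2.0f, static_cast<float>(i));
        feats.push_back(torch::sin(freq * x));  // elementwise over all coordinates
        feats.push_back(torch::cos(freq * x));
    }
    return torch::cat(feats, -1);
}

int main() {
    auto pts = torch::rand({1024, 3});
    std::cout << positional_encode(pts, 10).size(-1) << "\n";  // 3 + 2*3*10 = 63
    std::cout << positional_encode(pts, 4).size(-1)  << "\n";  // 3 + 2*3*4  = 27
}
```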
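The compositing in `Renderer::compute_accumulated_transmittance` uses a shifted `cumprod` to produce the usual NeRF weights `w_i = alpha_i * prod_{j<i}(1 - alpha_j)`. The following standalone check (not part of the sources above; `weights_fast`/`weights_ref` are just illustrative names) compares the vectorised form against an explicit loop:

```cpp
// Verifies the cumprod-based weights against the textbook definition
// w_i = alpha_i * prod_{j<i} (1 - alpha_j).
#include <torch/torch.h>
#include <iostream>

int main() {
    torch::manual_seed(0);
    auto alphas = torch::rand({4, 8});  // [N_rays, N_samples], values in (0, 1)

    // cumprod formulation, as in compute_accumulated_transmittance
    auto transmittance = torch::cumprod(1 - alphas + 1e-10, -1);
    auto shifted = torch::cat({torch::ones({4, 1}),
                               transmittance.index({torch::indexing::Slice(),
                                                    torch::indexing::Slice(0, -1)})}, -1);
    auto weights_fast = alphas * shifted;

    // explicit loop over samples for reference
    auto weights_ref = torch::zeros_like(alphas);
    for (int64_t i = 0; i < alphas.size(1); ++i) {
        auto t = torch::ones({alphas.size(0)});
        for (int64_t j = 0; j < i; ++j) {
            t = t * (1 - alphas.index({torch::indexing::Slice(), j}) + 1e-10);
        }
        weights_ref.index_put_({torch::indexing::Slice(), i},
                               alphas.index({torch::indexing::Slice(), i}) * t);
    }

    std::cout << std::boolalpha << torch::allclose(weights_fast, weights_ref) << "\n";
}
```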
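`examples/train.cpp` is referenced by `CMakeLists.txt` but is not part of this diff, so the sketch below is only a rough, hypothetical wiring of the three classes. It assumes `NeRFModel` derives from `torch::nn::Module` (it calls `register_module`), uses the constructor and method signatures from `src/model.cpp` and `src/renderer.cpp` above, and borrows the exponential learning-rate schedule from the removed Python trainer (`new_lrate = lrate * 0.1^(step / decay_steps)`); the 5e-4 / 250 values are the usual NeRF defaults, not something this diff pins down. The `Dataset` construction is commented out because the pose loaders in `src/dataset.cpp` are still stubs.

```cpp
// Rough wiring sketch only; not the project's examples/train.cpp.
#include <torch/torch.h>
#include <cmath>
#include <memory>
#include "nerf/model.hpp"
#include "nerf/renderer.hpp"

int main() {
    using namespace nerf;

    // Dataset dataset("data/nerf_synthetic/lego", "blender", 1, true, true);
    // (omitted: load_poses() in src/dataset.cpp still throws "not implemented")

    auto model = std::make_shared<NeRFModel>(/*netdepth=*/8, /*netwidth=*/256,
                                             /*netdepth_fine=*/8, /*netwidth_fine=*/256,
                                             /*multires=*/10, /*multires_views=*/4,
                                             /*use_viewdirs=*/true);
    Renderer renderer(model, /*N_samples=*/32, /*N_importance=*/0,
                      /*use_viewdirs=*/true, /*raw_noise_std=*/0.0f, /*white_bkgd=*/true);

    const double lrate = 5e-4;
    const double decay_steps = 250.0 * 1000.0;  // lrate_decay = 250
    torch::optim::Adam optimizer(model->parameters(), torch::optim::AdamOptions(lrate));

    // Tiny synthetic target so the loop runs quickly; real training would
    // draw images/poses from the Dataset instead.
    auto H = torch::tensor(32), W = torch::tensor(32);
    auto K = torch::tensor({{30.0f, 0.0f, 16.0f}, {0.0f, 30.0f, 16.0f}, {0.0f, 0.0f, 1.0f}});
    auto c2w = torch::eye(4);
    auto near = torch::tensor(2.0f), far = torch::tensor(6.0f);
    auto target = torch::rand({32, 32, 3});

    for (int step = 0; step < 10; ++step) {
        auto rgb = renderer.render(H, W, K, c2w, near, far, /*is_fine=*/false);
        auto loss = torch::mean(torch::pow(rgb - target, 2));  // img2mse

        optimizer.zero_grad();
        loss.backward();
        optimizer.step();

        // Exponential learning-rate decay, as in the removed Python trainer.
        double new_lrate = lrate * std::pow(0.1, step / decay_steps);
        for (auto& group : optimizer.param_groups()) {
            static_cast<torch::optim::AdamOptions&>(group.options()).lr(new_lrate);
        }
    }
}
```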