diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..b6c1da7 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,26 @@ +# Ignore the following files and directories when building the Docker image +*.pyc +__pycache__/ +*.ipynb_checkpoints +*.log +*.csv +*.tsv +*.h5 +*.pth +*.pt +*.zip +*.tar.gz +*.egg-info/ +dist/ +build/ +.env +venv/ +.env.local +*.DS_Store +*.egg +*.whl +*.pkl +*.json +*.yaml +*.yml +submodules/ \ No newline at end of file diff --git a/.gitignore b/.gitignore index a05a2b7..c38a86d 100644 --- a/.gitignore +++ b/.gitignore @@ -158,3 +158,6 @@ dmypy.json # Pyre type checker .pyre/ learnableearthparser/fast_sampler/_sampler.c + +# data +data/ diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..aa2d20a --- /dev/null +++ b/Dockerfile @@ -0,0 +1,60 @@ +FROM nvidia/cuda:12.1.1-devel-ubuntu22.04 + +# Set the working directory +WORKDIR /EDGS + +# Install system dependencies first, including git, build-essential, and cmake +RUN apt-get update && apt-get install -y \ + git \ + wget \ + build-essential \ + cmake \ + ninja-build \ + libgl1-mesa-glx \ + libglib2.0-0 \ + && rm -rf /var/lib/apt/lists/* + +# Copy only essential files for cloning submodules first (e.g., .gitmodules) +# Or, if submodules are public, you might not need to copy anything specific for this step +# For simplicity, we'll copy everything, but this could be optimized +COPY . . + +# Initialize and update submodules +RUN git submodule init && git submodule update --recursive + +# Install Miniconda +RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/miniconda.sh && \ + bash /tmp/miniconda.sh -b -p /opt/conda && \ + rm /tmp/miniconda.sh +ENV PATH="/opt/conda/bin:${PATH}" + +# Create the conda environment and install dependencies +# Accept Anaconda TOS before using conda +RUN conda init bash && \ + conda config --set always_yes yes --set changeps1 no && \ + conda config --add channels defaults && \ + conda config --set channel_priority strict && \ + conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main && \ + conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r +# Now you can safely create your environment +RUN conda create -y -n edgs python=3.10 pip && \ + conda clean -afy && \ + echo "source activate edgs" > ~/.bashrc + +# Set CUDA architectures to compile for +ENV TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6;8.9;9.0+PTX" + +# Activate the environment and install Python dependencies +RUN /bin/bash -c "source activate edgs && \ + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 && \ + pip install -e ./submodules/gaussian-splatting/submodules/diff-gaussian-rasterization && \ + pip install -e ./submodules/gaussian-splatting/submodules/simple-knn && \ + pip install pycolmap wandb hydra-core tqdm torchmetrics lpips matplotlib rich plyfile imageio imageio-ffmpeg && \ + pip install -e ./submodules/RoMa && \ + pip install gradio plotly scikit-learn moviepy==2.1.1 ffmpeg open3d jupyterlab matplotlib" + +# Expose the port for Gradio +EXPOSE 7862 + +# Keep the container running in detached mode +CMD ["tail", "-f", "/dev/null"] \ No newline at end of file diff --git a/README.md b/README.md index 8c40427..c9fb61d 100644 --- a/README.md +++ b/README.md @@ -69,45 +69,14 @@ Alternatively, check our [Colab notebook](https://colab.research.google.com/gith ## 🛠️ Installation -You can either run `install.sh` or manually install using the following: +You can install it 
with Docker Compose: ```bash -git clone git@github.com:CompVis/EDGS.git --recursive -cd EDGS -git submodule update --init --recursive - -conda create -y -n edgs python=3.10 pip -conda activate edgs - -# Set up path to your CUDA. In our experience similar versions like 12.2 also work well -export CUDA_HOME=/usr/local/cuda-12.1 -export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH -export PATH=$CUDA_HOME/bin:$PATH - -conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia -y -conda install nvidia/label/cuda-12.1.0::cuda-toolkit -y - -pip install -e submodules/gaussian-splatting/submodules/diff-gaussian-rasterization -pip install -e submodules/gaussian-splatting/submodules/simple-knn - -# For COLMAP and pycolmap -# Optionally install original colmap but probably pycolmap suffices -# conda install conda-forge/label/colmap_dev::colmap -pip install pycolmap - - -pip install wandb hydra-core tqdm torchmetrics lpips matplotlib rich plyfile imageio imageio-ffmpeg -conda install numpy=1.26.4 -y -c conda-forge --override-channels - -pip install -e submodules/RoMa -conda install anaconda::jupyter --yes - -# Stuff necessary for gradio and visualizations -pip install gradio -pip install plotly scikit-learn moviepy==2.1.1 ffmpeg -pip install open3d +docker compose up -d ``` +Alternatively, you can install it by running `script/install.sh`. + ## 📦 Data @@ -118,6 +87,36 @@ We evaluated on the following datasets: ### Using Your Own Dataset +#### Option A +Use the Gradio demo. +After running `docker compose up -d`, +``` +docker compose exec edgs-app bash +python script/gradio_demo.py --port 7862 +``` + +#### Option B +From the command line. +``` +docker compose exec edgs-app bash +python script/fit_model_to_scene_full.py --video_path <path_to_video> [--outputs_dir <path_to_output_dir>] +``` + +#### Option C +Using JupyterLab. +``` +docker compose exec edgs-app bash +``` +Then, in the terminal inside the Docker container, run +``` +jupyter lab --ip=0.0.0.0 --port=8888 --no-browser --allow-root --notebook-dir=notebooks +``` +After JupyterLab starts, it will print URLs to the terminal. Look for a URL containing a token, like: + `http://127.0.0.1:8888/lab?token=xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx` +Open `http://localhost:8888` (or `http://127.0.0.1:8888`) in your host browser. +When prompted for a "Password or token", paste the token part (`xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx`) of that URL into the field and log in. Alternatively, you can paste that full URL directly into your browser. + +#### Option D You can use the same data format as the [3DGS project](https://github.com/graphdeco-inria/gaussian-splatting?tab=readme-ov-file#processing-your-own-scenes). Please follow their guide to prepare your scene. Expected folder structure: @@ -134,6 +133,11 @@ scene_folder |---points3D.bin ``` +``` +docker compose exec edgs-app bash +``` +Then run the training command as described in the section below. + Nerf synthetic format is also acceptable. You can also use functions provided in our code to convert a collection of images or a single video into a desired format. However, this may require tweaking and processing time can be large for large collection of images with little overlap. @@ -144,11 +148,11 @@ You can also use functions provided in our code to convert a collection of image To optimize on a single scene in COLMAP format use this code.
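For orientation, here is a minimal end-to-end sketch (not part of the diff itself) of how the pieces added in this PR fit together when optimizing your own COLMAP scene inside the container. It assumes the `edgs-app` service and the `./data`/`./outputs` volume mounts declared in the `docker-compose.yml` above; `my_scene` is a placeholder folder name. The exact training command documented in the README follows right after.

```bash
# Sketch only: paths assume the volume mounts ./data -> /EDGS/data and ./outputs -> /EDGS/outputs
# from docker-compose.yml; "my_scene" is a placeholder COLMAP scene folder placed under ./data on the host.
docker compose up -d
docker compose exec edgs-app bash
# Inside the container (the "edgs" conda env is activated via ~/.bashrc, see the Dockerfile):
python script/train.py \
    train.gs_epochs=30000 \
    train.no_densify=True \
    gs.dataset.source_path=/EDGS/data/my_scene \
    gs.dataset.output_path=/EDGS/outputs/my_scene \
    init_wC.matches_per_ref=20000 \
    init_wC.nns_per_ref=3 \
    init_wC.num_refs=180
```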
```bash -python train.py \ +python script/train.py \ train.gs_epochs=30000 \ train.no_densify=True \ gs.dataset.source_path= \ - gs.dataset.model_path= \ + gs.dataset.output_path= \ init_wC.matches_per_ref=20000 \ init_wC.nns_per_ref=3 \ init_wC.num_refs=180 @@ -162,7 +166,7 @@ python train.py \ Disables densification. True by default. * `gs.dataset.source_path` Path to your input dataset directory. This should follow the same format as the original 3DGS dataset structure. - * `gs.dataset.model_path` + * `gs.dataset.output_path` Output directory where the trained model, logs, and renderings will be saved. * `init_wC.matches_per_ref` Number of 2D feature correspondences to extract per reference view for initialization. More matches leads to more gaussians. diff --git a/configs/gs/base.yaml b/configs/gs/base.yaml index 5a74ddc..bc0501d 100644 --- a/configs/gs/base.yaml +++ b/configs/gs/base.yaml @@ -41,7 +41,8 @@ pipe: dataset: densify_until_iter: 15000 source_path: '' #path to dataset - model_path: '' #path to logs + output_path: '' #path to output directory + model_path: '${.output_path}' #alias for output_path (required by submodule) images: images resolution: -1 white_background: false diff --git a/configs/train.yaml b/configs/train.yaml index e40ec5b..6e776d0 100644 --- a/configs/train.yaml +++ b/configs/train.yaml @@ -1,22 +1,22 @@ defaults: - gs: base - - _self_ + - _self_ seed: 228 wandb: - mode: "online" # "disabled" for no logging + mode: "disabled" # "online" or "disabled" entity: "3dcorrespondence" project: "Adv3DGS" group: null name: null tag: "debug" - + train: - gs_epochs: 0 # number of 3dgs iterations - reduce_opacity: True + gs_epochs: 1000 # number of 3dgs iterations + reduce_opacity: True no_densify: False # if True, the model will not be densified - max_lr: True + max_lr: True load: gs: null #path to 3dgs checkpoint @@ -28,11 +28,9 @@ verbose: true init_wC: use: True # use EDGS matches_per_ref: 15_000 # number of matches per reference - num_refs: 180 # number of reference images + num_refs: 18 # number of reference images nns_per_ref: 3 # number of nearest neighbors per reference scaling_factor: 0.001 proj_err_tolerance: 0.01 roma_model: "outdoors" # you can change this to "indoors" or "outdoors" - add_SfM_init : False - - + add_SfM_init: False diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..26df04c --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,26 @@ +services: + edgs-app: + build: . 
# Instructs Docker Compose to build using the Dockerfile in the current directory + image: edgs-app # This is the name of the image you built + ports: + - "7862:7862" # For Gradio, if you still use it + - "8888:8888" # Map port 8888 for JupyterLab + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: all # Use all available GPUs + capabilities: [gpu] # Request GPU capabilities + environment: + - NVIDIA_VISIBLE_DEVICES=all + volumes: + - ./data:/EDGS/data # Example: map a local 'data' folder to '/EDGS/data' in the container + - ./outputs:/EDGS/outputs # Example: map a local 'output' folder + - ./script:/EDGS/script # Example: map a local 'scripts' folder + - ./source:/EDGS/source # Example: map a local 'sources' folder + - ./configs:/EDGS/configs # Example: map a local 'config' folder + - ./notebooks:/EDGS/notebooks # Map a local 'notebooks' folder for JupyterLab + - ./model:/EDGS/model # Map model directory for persistent model caching + stdin_open: true # Keep STDIN open for interactive processes + tty: true # Allocate a TTY \ No newline at end of file diff --git a/notebooks/fit_model_to_scene_full.ipynb b/notebooks/fit_model_to_scene_full.ipynb index 09e6323..1d0b6c6 100644 --- a/notebooks/fit_model_to_scene_full.ipynb +++ b/notebooks/fit_model_to_scene_full.ipynb @@ -67,10 +67,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "59c64632-e31a-4ead-98a5-4ab0f295e54d", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "xFormers not available\n" + ] + } + ], "source": [ "import torch\n", "import numpy as np\n", @@ -96,10 +104,96 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "925adfa3-c311-44b6-a8c4-a31fb7426947", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "gs:\n", + " _target_: source.networks.Warper3DGS\n", + " verbose: true\n", + " viewpoint_stack: null\n", + " sh_degree: 3\n", + " opt:\n", + " iterations: 30000\n", + " position_lr_init: 0.00016\n", + " position_lr_final: 1.6e-06\n", + " position_lr_delay_mult: 0.01\n", + " position_lr_max_steps: 30000\n", + " feature_lr: 0.0025\n", + " opacity_lr: 0.025\n", + " scaling_lr: 0.005\n", + " rotation_lr: 0.001\n", + " percent_dense: 0.01\n", + " lambda_dssim: 0.2\n", + " densification_interval: 100\n", + " opacity_reset_interval: 30000\n", + " densify_from_iter: 500\n", + " densify_until_iter: 15000\n", + " densify_grad_threshold: 0.0002\n", + " random_background: false\n", + " save_iterations:\n", + " - 3000\n", + " - 7000\n", + " - 15000\n", + " - 30000\n", + " batch_size: 64\n", + " exposure_lr_init: 0.01\n", + " exposure_lr_final: 0.0001\n", + " exposure_lr_delay_steps: 0\n", + " exposure_lr_delay_mult: 0.0\n", + " TRAIN_CAM_IDX_TO_LOG: 50\n", + " TEST_CAM_IDX_TO_LOG: 10\n", + " pipe:\n", + " convert_SHs_python: false\n", + " compute_cov3D_python: false\n", + " debug: false\n", + " antialiasing: false\n", + " dataset:\n", + " densify_until_iter: 15000\n", + " source_path: ''\n", + " model_path: ''\n", + " images: images\n", + " resolution: -1\n", + " white_background: false\n", + " data_device: cuda\n", + " eval: false\n", + " depths: ''\n", + " train_test_exp: false\n", + "seed: 228\n", + "wandb:\n", + " mode: online\n", + " entity: m-ogawa-sensyn\n", + " project: Adv3DGS\n", + " group: null\n", + " name: null\n", + " tag: debug\n", + "train:\n", + " gs_epochs: 0\n", + " reduce_opacity: true\n", + " 
no_densify: false\n", + " max_lr: true\n", + "load:\n", + " gs: null\n", + " gs_step: null\n", + "device: cuda:0\n", + "verbose: true\n", + "init_wC:\n", + " use: true\n", + " matches_per_ref: 15000\n", + " num_refs: 180\n", + " nns_per_ref: 3\n", + " scaling_factor: 0.001\n", + " proj_err_tolerance: 0.01\n", + " roma_model: outdoors\n", + " add_SfM_init: false\n", + "\n" + ] + } + ], "source": [ "with initialize(config_path=\"../configs\", version_base=\"1.1\"):\n", " cfg = compose(config_name=\"train\")\n", @@ -124,7 +218,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "07e4ca51", "metadata": {}, "outputs": [], @@ -177,10 +271,40 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "2056ee6f-dbb6-4ce8-86f0-5b4f9721d093", - "metadata": {}, - "outputs": [], + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Output folder: ./scene_edgsed/\n" + ] + }, + { + "ename": "InstantiationException", + "evalue": "Error in call to target 'source.networks.Warper3DGS':\nAssertionError('Could not recognize scene type!')\nfull_key: gs", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "File \u001b[0;32m/opt/conda/envs/edgs/lib/python3.10/site-packages/hydra/_internal/instantiate/_instantiate2.py:92\u001b[0m, in \u001b[0;36m_call_target\u001b[0;34m(_target_, _partial_, args, kwargs, full_key)\u001b[0m\n\u001b[1;32m 91\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m---> 92\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_target_\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 93\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n", + "File \u001b[0;32m/EDGS/notebooks/../source/networks.py:26\u001b[0m, in \u001b[0;36mWarper3DGS.__init__\u001b[0;34m(self, sh_degree, opt, pipe, dataset, viewpoint_stack, verbose, do_train_test_split)\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpipe \u001b[38;5;241m=\u001b[39m pipe\n\u001b[0;32m---> 26\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscene \u001b[38;5;241m=\u001b[39m \u001b[43mScene\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgaussians\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mshuffle\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 27\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m do_train_test_split:\n", + "File \u001b[0;32m/EDGS/notebooks/../submodules/gaussian-splatting/scene/__init__.py:49\u001b[0m, in \u001b[0;36mScene.__init__\u001b[0;34m(self, args, gaussians, load_iteration, shuffle, resolution_scales)\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 49\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28;01mFalse\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCould not recognize scene 
type!\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 51\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mloaded_iter:\n", + "\u001b[0;31mAssertionError\u001b[0m: Could not recognize scene type!", + "\nThe above exception was the direct cause of the following exception:\n", + "\u001b[0;31mInstantiationException\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[5], line 12\u001b[0m\n\u001b[1;32m 10\u001b[0m os\u001b[38;5;241m.\u001b[39mmakedirs(cfg\u001b[38;5;241m.\u001b[39mgs\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39mmodel_path, exist_ok\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 11\u001b[0m \u001b[38;5;66;03m# Init gs model\u001b[39;00m\n\u001b[0;32m---> 12\u001b[0m gs \u001b[38;5;241m=\u001b[39m \u001b[43mhydra\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mutils\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minstantiate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcfg\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgs\u001b[49m\u001b[43m)\u001b[49m \n\u001b[1;32m 13\u001b[0m trainer \u001b[38;5;241m=\u001b[39m EDGSTrainer(GS\u001b[38;5;241m=\u001b[39mgs,\n\u001b[1;32m 14\u001b[0m training_config\u001b[38;5;241m=\u001b[39mcfg\u001b[38;5;241m.\u001b[39mgs\u001b[38;5;241m.\u001b[39mopt,\n\u001b[1;32m 15\u001b[0m device\u001b[38;5;241m=\u001b[39mcfg\u001b[38;5;241m.\u001b[39mdevice)\n", + "File \u001b[0;32m/opt/conda/envs/edgs/lib/python3.10/site-packages/hydra/_internal/instantiate/_instantiate2.py:226\u001b[0m, in \u001b[0;36minstantiate\u001b[0;34m(config, *args, **kwargs)\u001b[0m\n\u001b[1;32m 223\u001b[0m _convert_ \u001b[38;5;241m=\u001b[39m config\u001b[38;5;241m.\u001b[39mpop(_Keys\u001b[38;5;241m.\u001b[39mCONVERT, ConvertMode\u001b[38;5;241m.\u001b[39mNONE)\n\u001b[1;32m 224\u001b[0m _partial_ \u001b[38;5;241m=\u001b[39m config\u001b[38;5;241m.\u001b[39mpop(_Keys\u001b[38;5;241m.\u001b[39mPARTIAL, \u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[0;32m--> 226\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minstantiate_node\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 227\u001b[0m \u001b[43m \u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrecursive\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_recursive_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconvert\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_convert_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpartial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_partial_\u001b[49m\n\u001b[1;32m 228\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 229\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m OmegaConf\u001b[38;5;241m.\u001b[39mis_list(config):\n\u001b[1;32m 230\u001b[0m \u001b[38;5;66;03m# Finalize config (convert targets to strings, merge with kwargs)\u001b[39;00m\n\u001b[1;32m 231\u001b[0m config_copy \u001b[38;5;241m=\u001b[39m copy\u001b[38;5;241m.\u001b[39mdeepcopy(config)\n", + "File \u001b[0;32m/opt/conda/envs/edgs/lib/python3.10/site-packages/hydra/_internal/instantiate/_instantiate2.py:347\u001b[0m, in \u001b[0;36minstantiate_node\u001b[0;34m(node, convert, recursive, partial, *args)\u001b[0m\n\u001b[1;32m 342\u001b[0m value \u001b[38;5;241m=\u001b[39m instantiate_node(\n\u001b[1;32m 343\u001b[0m value, convert\u001b[38;5;241m=\u001b[39mconvert, 
recursive\u001b[38;5;241m=\u001b[39mrecursive\n\u001b[1;32m 344\u001b[0m )\n\u001b[1;32m 345\u001b[0m kwargs[key] \u001b[38;5;241m=\u001b[39m _convert_node(value, convert)\n\u001b[0;32m--> 347\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_call_target\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_target_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpartial\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfull_key\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 348\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 349\u001b[0m \u001b[38;5;66;03m# If ALL or PARTIAL non structured or OBJECT non structured,\u001b[39;00m\n\u001b[1;32m 350\u001b[0m \u001b[38;5;66;03m# instantiate in dict and resolve interpolations eagerly.\u001b[39;00m\n\u001b[1;32m 351\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m convert \u001b[38;5;241m==\u001b[39m ConvertMode\u001b[38;5;241m.\u001b[39mALL \u001b[38;5;129;01mor\u001b[39;00m (\n\u001b[1;32m 352\u001b[0m convert \u001b[38;5;129;01min\u001b[39;00m (ConvertMode\u001b[38;5;241m.\u001b[39mPARTIAL, ConvertMode\u001b[38;5;241m.\u001b[39mOBJECT)\n\u001b[1;32m 353\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m node\u001b[38;5;241m.\u001b[39m_metadata\u001b[38;5;241m.\u001b[39mobject_type \u001b[38;5;129;01min\u001b[39;00m (\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;28mdict\u001b[39m)\n\u001b[1;32m 354\u001b[0m ):\n", + "File \u001b[0;32m/opt/conda/envs/edgs/lib/python3.10/site-packages/hydra/_internal/instantiate/_instantiate2.py:97\u001b[0m, in \u001b[0;36m_call_target\u001b[0;34m(_target_, _partial_, args, kwargs, full_key)\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m full_key:\n\u001b[1;32m 96\u001b[0m msg \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124mfull_key: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfull_key\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m---> 97\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m InstantiationException(msg) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01me\u001b[39;00m\n", + "\u001b[0;31mInstantiationException\u001b[0m: Error in call to target 'source.networks.Warper3DGS':\nAssertionError('Could not recognize scene type!')\nfull_key: gs" + ] + } + ], "source": [ "_ = wandb.init(entity=cfg.wandb.entity,\n", " project=cfg.wandb.project,\n", @@ -1781,7 +1905,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.21" + "version": "3.10.18" } }, "nbformat": 4, diff --git a/script/fit_model_to_scene_full.py b/script/fit_model_to_scene_full.py new file mode 100644 index 0000000..851797b --- /dev/null +++ b/script/fit_model_to_scene_full.py @@ -0,0 +1,268 @@ +#!/usr/bin/env python +# coding: utf-8 + +# # EDGS: Eliminating Densification for Gaussian Splatting +# EDGS improves 3D Gaussian Splatting by removing the need for densification. It starts from a dense point cloud initialization based on 2D correspondences, leading to: +# - ⚡ Faster convergence (only 25% of training time) +# - 🌀 Higher rendering quality +# - 💡 No need for progressive densification + +# ## 2. 
Import libraries +import argparse +import logging +import os +import random +import sys + +import hydra +import numpy as np +import omegaconf +import torch +import wandb +from hydra import compose, initialize +from matplotlib import pyplot as plt +from omegaconf import OmegaConf + +# Add the project root directory to sys.path +# so that modules from 'source' can be imported. +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if project_root not in sys.path: + sys.path.insert(0, project_root) +# sys.path.append("../submodules/gaussian-splatting") +from source.trainer import EDGSTrainer +from source.utils_aux import set_seed +from source.utils_preprocess import ( + orchestrate_video_to_colmap_scene, # Use the refactored function +) + +# Initialize logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", +) + +# --- Add argument parsing --- +parser = argparse.ArgumentParser( + description="Fit EDGS model to a scene, optionally from a video." +) +parser.add_argument( + "--video_path", + type=str, + default=os.path.join( + project_root, "assets", "examples", "video_fruits.mp4" + ), # Use project_root + help="Path to the input video file.", +) +parser.add_argument( + "--outputs_dir", + type=str, + default=os.path.join(project_root, "outputs"), # Use project_root + help="Base directory where processed COLMAP scenes will be stored.", +) +args = parser.parse_args() +# --- End argument parsing --- + +with initialize(config_path="../configs", version_base="1.1"): + cfg = compose(config_name="train") + +SAME_WITH_GRADIO_DEMO = True +if SAME_WITH_GRADIO_DEMO: + cfg.gs.opt.opacity_reset_interval = 1_000_000 + cfg.train.reduce_opacity = True + cfg.train.no_densify = True + cfg.train.max_lr = True + cfg.train.gs_epochs = 1000 + + cfg.init_wC.use = True + cfg.init_wC.nns_per_ref = 1 + cfg.init_wC.add_SfM_init = False + cfg.init_wC.scaling_factor = 0.00077 * 2.0 + cfg.init_wC.num_refs = 16 + cfg.init_wC.matches_per_ref = 20000 + +print(OmegaConf.to_yaml(cfg)) + + +# # 3. Init input parameters + +# ## 3.1 Optionally preprocess video +# process the input video +if os.path.exists(args.video_path): + print(f"Starting video processing for: {args.video_path}") + try: + # The first return value 'images_data' might not be directly used by the trainer + # if the Scene object loads everything from the COLMAP directory. + _, scene_dir = orchestrate_video_to_colmap_scene( + args.video_path, + cfg.init_wC.num_refs, # Assuming you added this arg + max_size=1024, # Or make it an arg + base_work_dir=args.outputs_dir, # Assuming you added this arg + ) + if scene_dir is None: + print(f"Failed to process video {args.video_path}. Exiting.") + sys.exit(1) + cfg.gs.dataset.source_path = scene_dir + cfg.gs.dataset.output_path = os.path.join(scene_dir, "models") + print(f"Set output_path to: {cfg.gs.dataset.output_path}") + os.makedirs(cfg.gs.dataset.output_path, exist_ok=True) + except Exception as e: + print(f"Error during video preprocessing: {e}") + sys.exit(1) + + +# # 4. Initilize model and logger +if cfg.wandb.mode != "disabled": + logging.info( + "wandb logging is enabled (mode={}). Results will be logged to wandb.".format( + cfg.wandb.mode + ) + ) + _ = wandb.init( + entity=cfg.wandb.entity, + project=cfg.wandb.project, + config=omegaconf.OmegaConf.to_container( + cfg, resolve=True, throw_on_missing=True + ), + name=cfg.wandb.name, + mode=cfg.wandb.mode, + ) +else: + logging.info( + "wandb logging is disabled (mode={}). 
Results will not be logged to wandb.".format( + cfg.wandb.mode + ) + ) +omegaconf.OmegaConf.resolve(cfg) +set_seed(cfg.seed) +# Init output folder +print("Output folder: {}".format(cfg.gs.dataset.output_path)) +os.makedirs(cfg.gs.dataset.output_path, exist_ok=True) +# Init gs model +gs = hydra.utils.instantiate(cfg.gs) +trainer = EDGSTrainer( + GS=gs, + training_config=cfg.gs.opt, + device=cfg.device, + log_wandb=(cfg.wandb.mode != "disabled"), +) + + +# # 5. Init with matchings +trainer.timer.start() +trainer.init_with_corr(cfg.init_wC) +trainer.timer.pause() + + +# ### Visualize a few initial viewpoints +with torch.no_grad(): + viewpoint_stack = trainer.GS.scene.getTrainCameras() + viewpoint_cams_to_viz = random.sample(trainer.GS.scene.getTrainCameras(), 4) + for idx, viewpoint_cam in enumerate(viewpoint_cams_to_viz): + render_pkg = trainer.GS(viewpoint_cam) + image = render_pkg["render"] + + image_np = image.clone().detach().cpu().numpy().transpose(1, 2, 0) + image_gt_np = ( + viewpoint_cam.original_image.clone() + .detach() + .cpu() + .numpy() + .transpose(1, 2, 0) + ) + + # Clip values to be in the range [0, 1] + image_np = np.clip(image_np * 255, 0, 255).astype(np.uint8) + image_gt_np = np.clip(image_gt_np * 255, 0, 255).astype(np.uint8) + + fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6)) + ax[0].imshow(image_gt_np) + ax[0].axis("off") + ax[1].imshow(image_np) + ax[1].axis("off") + plt.tight_layout() + plt.savefig( + os.path.join( + cfg.gs.dataset.output_path, + f"viewpoint_{idx}_initial.png", + ) + ) + plt.show() + plt.close(fig) + + +# # 6.Optimize scene +# Optimize first briefly for 5k steps and visualize results. We also disable saving of pretrained models. Train function can be changed for any other method +trainer.saving_iterations = [] +# cfg.train.gs_epochs = 5_000 +trainer.train(cfg.train) + + +# ### Visualize same viewpoints +with torch.no_grad(): + for viewpoint_cam in viewpoint_cams_to_viz: + render_pkg = trainer.GS(viewpoint_cam) + image = render_pkg["render"] + + image_np = image.clone().detach().cpu().numpy().transpose(1, 2, 0) + image_gt_np = ( + viewpoint_cam.original_image.clone() + .detach() + .cpu() + .numpy() + .transpose(1, 2, 0) + ) + + # Clip values to be in the range [0, 1] + image_np = np.clip(image_np * 255, 0, 255).astype(np.uint8) + image_gt_np = np.clip(image_gt_np * 255, 0, 255).astype(np.uint8) + + fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6)) + ax[0].imshow(image_gt_np) + ax[0].axis("off") + ax[1].imshow(image_np) + ax[1].axis("off") + plt.tight_layout() + plt.show() + + +# ### Save model +with torch.no_grad(): + trainer.save_model() + + +# # # 7. 
Continue training until we reach total 30K training steps +# cfg.train.gs_epochs = 25_000 +# trainer.train(cfg.train) + + +# # ### Visualize same viewpoints +# with torch.no_grad(): +# for viewpoint_cam in viewpoint_cams_to_viz: +# render_pkg = trainer.GS(viewpoint_cam) +# image = render_pkg["render"] + +# image_np = image.clone().detach().cpu().numpy().transpose(1, 2, 0) +# image_gt_np = ( +# viewpoint_cam.original_image.clone() +# .detach() +# .cpu() +# .numpy() +# .transpose(1, 2, 0) +# ) + +# # Clip values to be in the range [0, 1] +# image_np = np.clip(image_np * 255, 0, 255).astype(np.uint8) +# image_gt_np = np.clip(image_gt_np * 255, 0, 255).astype(np.uint8) + +# fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6)) +# ax[0].imshow(image_gt_np) +# ax[0].axis("off") +# ax[1].imshow(image_np) +# ax[1].axis("off") +# plt.tight_layout() +# plt.show() + + +# ### Save model +# with torch.no_grad(): +# trainer.save_model() diff --git a/full_eval.py b/script/full_eval.py similarity index 81% rename from full_eval.py rename to script/full_eval.py index 5342566..60f7dc7 100644 --- a/full_eval.py +++ b/script/full_eval.py @@ -42,19 +42,19 @@ for scene in mipnerf360_outdoor_scenes: source = args.mipnerf360 + "/" + scene experiment = name + scene - os.system(f"python train.py verbose=True gs.dataset.source_path={source} gs.dataset.model_path={args.output_path}/mipnerf/{scene} wandb.name={experiment} init_wC.use=True train.gs_epochs=30000 init_wC.matches_per_ref=25_000 init_wC.nns_per_ref=3 gs.dataset.images=images_4 init_wC.num_refs=180 train.no_densify=True") + os.system(f"python train.py verbose=True gs.dataset.source_path={source} gs.dataset.output_path={args.output_path}/mipnerf/{scene} wandb.name={experiment} init_wC.use=True train.gs_epochs=30000 init_wC.matches_per_ref=25_000 init_wC.nns_per_ref=3 gs.dataset.images=images_4 init_wC.num_refs=180 train.no_densify=True") for scene in mipnerf360_indoor_scenes: source = args.mipnerf360 + "/" + scene experiment = name + scene - os.system(f"python train.py verbose=True gs.dataset.source_path={source} gs.dataset.model_path={args.output_path}/mipnerf/{scene} wandb.name={experiment} init_wC.use=True train.gs_epochs=30000 init_wC.matches_per_ref=25_000 init_wC.nns_per_ref=3 gs.dataset.images=images_2 init_wC.num_refs=180 train.no_densify=True") + os.system(f"python train.py verbose=True gs.dataset.source_path={source} gs.dataset.output_path={args.output_path}/mipnerf/{scene} wandb.name={experiment} init_wC.use=True train.gs_epochs=30000 init_wC.matches_per_ref=25_000 init_wC.nns_per_ref=3 gs.dataset.images=images_2 init_wC.num_refs=180 train.no_densify=True") for scene in tanks_and_temples_scenes: source = args.tanksandtemples + "/" + scene experiment = name + scene +"_tandt" - os.system(f"python train.py verbose=True gs.dataset.source_path={source} gs.dataset.model_path={args.output_path}/mipnerf/{scene} wandb.name={experiment} init_wC.use=True train.gs_epochs=30000 init_wC.matches_per_ref=15_000 init_wC.nns_per_ref=3 init_wC.num_refs=180 train.no_densify=True") + os.system(f"python train.py verbose=True gs.dataset.source_path={source} gs.dataset.output_path={args.output_path}/mipnerf/{scene} wandb.name={experiment} init_wC.use=True train.gs_epochs=30000 init_wC.matches_per_ref=15_000 init_wC.nns_per_ref=3 init_wC.num_refs=180 train.no_densify=True") for scene in deep_blending_scenes: source = args.deepblending + "/" + scene experiment = name + scene + "_db" - os.system(f"python train.py verbose=True gs.dataset.source_path={source} 
gs.dataset.model_path={args.output_path}/mipnerf/{scene} wandb.name={experiment} init_wC.use=True train.gs_epochs=30000 init_wC.matches_per_ref=15_000 init_wC.nns_per_ref=3 init_wC.num_refs=180 train.no_densify=True") + os.system(f"python train.py verbose=True gs.dataset.source_path={source} gs.dataset.output_path={args.output_path}/mipnerf/{scene} wandb.name={experiment} init_wC.use=True train.gs_epochs=30000 init_wC.matches_per_ref=15_000 init_wC.nns_per_ref=3 init_wC.num_refs=180 train.no_densify=True") if not args.skip_rendering: diff --git a/gradio_demo.py b/script/gradio_demo.py similarity index 51% rename from gradio_demo.py rename to script/gradio_demo.py index 9c55e44..ec6abbc 100644 --- a/gradio_demo.py +++ b/script/gradio_demo.py @@ -1,29 +1,42 @@ -import torch +import argparse +import contextlib +import io import os import shutil -import tempfile -import argparse -import gradio as gr import sys -import io -from PIL import Image -import numpy as np -from source.utils_aux import set_seed -from source.utils_preprocess import read_video_frames, preprocess_frames, select_optimal_frames, save_frames_to_scene_dir, run_colmap_on_scene -from source.trainer import EDGSTrainer -from hydra import initialize, compose -import hydra +import tempfile import time -from source.visualization import generate_circular_camera_path, save_numpy_frames_as_mp4, generate_fully_smooth_cameras_with_tsp, put_text_on_image -import contextlib -import base64 +import gradio as gr +import hydra +import numpy as np +import torch +from hydra import compose, initialize -# Init RoMA model: -sys.path.append('../submodules/RoMa') -from romatch import roma_outdoor, roma_indoor +# Add the project root directory to sys.path +# so that modules from 'source' can be imported. +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if project_root not in sys.path: + sys.path.insert(0, project_root) -roma_model = roma_indoor(device="cuda:0") +from source.trainer import EDGSTrainer +from source.utils_aux import set_seed +from source.utils_preprocess import ( + orchestrate_video_to_colmap_scene, # Import the new/refactored function + process_input_for_colmap, + run_colmap_on_scene, +) +from source.visualization import ( + generate_circular_camera_path, + generate_fully_smooth_cameras_with_tsp, + put_text_on_image, + save_numpy_frames_as_mp4, +) + +# Init RoMA model with custom cache: +from source.model_utils import load_roma_model_with_custom_cache + +roma_model = load_roma_model_with_custom_cache(model_type="indoor", device="cuda:0") roma_model.upsample_preds = False roma_model.symmetric = False @@ -33,6 +46,7 @@ trainer = None + class Tee(io.TextIOBase): def __init__(self, *streams): self.streams = streams @@ -46,6 +60,7 @@ def flush(self): for stream in self.streams: stream.flush() + def capture_logs(func, *args, **kwargs): log_capture_string = io.StringIO() tee = Tee(sys.__stdout__, log_capture_string) @@ -53,27 +68,30 @@ def capture_logs(func, *args, **kwargs): result = func(*args, **kwargs) return result, log_capture_string.getvalue() + # Training Pipeline -def run_training_pipeline(scene_dir, - num_ref_views=16, - num_corrs_per_view=20000, - num_steps=1_000, - mode_toggle="Ours (EDGS)"): - with initialize(config_path="./configs", version_base="1.1"): +def run_training_pipeline( + scene_dir, + num_ref_views=16, + num_corrs_per_view=20000, + num_steps=1_000, + mode_toggle="Ours (EDGS)", +): + with initialize(config_path="../configs", version_base="1.1"): cfg = compose(config_name="train") scene_name 
= os.path.basename(scene_dir) - model_output_dir = f"./outputs/{scene_name}_trained" + model_output_dir = f"../outputs/{scene_name}_trained" cfg.wandb.mode = "disabled" - cfg.gs.dataset.model_path = model_output_dir + cfg.gs.dataset.output_path = model_output_dir cfg.gs.dataset.source_path = scene_dir cfg.gs.dataset.images = "images" cfg.gs.opt.TEST_CAM_IDX_TO_LOG = 12 cfg.train.gs_epochs = 30000 - - if mode_toggle=="Ours (EDGS)": + + if mode_toggle == "Ours (EDGS)": cfg.gs.opt.opacity_reset_interval = 1_000_000 cfg.train.reduce_opacity = True cfg.train.no_densify = True @@ -84,15 +102,20 @@ def run_training_pipeline(scene_dir, cfg.init_wC.nns_per_ref = 1 cfg.init_wC.num_refs = num_ref_views cfg.init_wC.add_SfM_init = False - cfg.init_wC.scaling_factor = 0.00077 * 2. - + cfg.init_wC.scaling_factor = 0.00077 * 2.0 + set_seed(cfg.seed) - os.makedirs(cfg.gs.dataset.model_path, exist_ok=True) + os.makedirs(cfg.gs.dataset.output_path, exist_ok=True) global trainer global MODEL_PATH generator3dgs = hydra.utils.instantiate(cfg.gs, do_train_test_split=False) - trainer = EDGSTrainer(GS=generator3dgs, training_config=cfg.gs.opt, device=cfg.device, log_wandb=cfg.wandb.mode != 'disabled') + trainer = EDGSTrainer( + GS=generator3dgs, + training_config=cfg.gs.opt, + device=cfg.device, + log_wandb=cfg.wandb.mode != "disabled", + ) # Disable evaluation and saving trainer.saving_iterations = [] @@ -102,13 +125,15 @@ def run_training_pipeline(scene_dir, trainer.timer.start() start_time = time.time() trainer.init_with_corr(cfg.init_wC, roma_model=roma_model) - time_for_init = time.time()-start_time + time_for_init = time.time() - start_time viewpoint_cams = trainer.GS.scene.getTrainCameras() - path_cameras = generate_fully_smooth_cameras_with_tsp(existing_cameras=viewpoint_cams, - n_selected=6, # 8 - n_points_per_segment=30, # 30 - closed=False) + path_cameras = generate_fully_smooth_cameras_with_tsp( + existing_cameras=viewpoint_cams, + n_selected=6, # 8 + n_points_per_segment=30, # 30 + closed=False, + ) path_cameras = path_cameras + path_cameras[::-1] path_renderings = [] @@ -122,13 +147,24 @@ def run_training_pipeline(scene_dir, image = render_pkg["render"] image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1) image_np = (image_np * 255).astype(np.uint8) - path_renderings.append(put_text_on_image(img=image_np, - text=f"Init stage.\nTime:{time_for_init:.3f}s. ")) - path_renderings = path_renderings + [put_text_on_image(img=image_np, text=f"Start fitting.\nTime:{time_for_init:.3f}s. ")]*30 - + path_renderings.append( + put_text_on_image( + img=image_np, text=f"Init stage.\nTime:{time_for_init:.3f}s. " + ) + ) + path_renderings = ( + path_renderings + + [ + put_text_on_image( + img=image_np, text=f"Start fitting.\nTime:{time_for_init:.3f}s. " + ) + ] + * 30 + ) + # Train and save visualizations during training. start_time = time.time() - for _ in range(int(num_steps//10)): + for _ in range(int(num_steps // 10)): with torch.no_grad(): viewpoint_cam = path_cameras[idx] idx = (idx + 1) % len(path_cameras) @@ -136,20 +172,27 @@ def run_training_pipeline(scene_dir, image = render_pkg["render"] image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1) image_np = (image_np * 255).astype(np.uint8) - path_renderings.append(put_text_on_image( - img=image_np, - text=f"Fitting stage.\nTime:{time_for_init + time.time()-start_time:.3f}s. ")) - + path_renderings.append( + put_text_on_image( + img=image_np, + text=f"Fitting stage.\nTime:{time_for_init + time.time() - start_time:.3f}s. 
", + ) + ) + cfg.train.gs_epochs = 10 trainer.train(cfg.train) - print(f"Time elapsed: {(time_for_init + time.time()-start_time):.2f}s.") + print(f"Time elapsed: {(time_for_init + time.time() - start_time):.2f}s.") # if (cfg.init_wC.use == False) and (time_for_init + time.time()-start_time) > 60: # break final_time = time.time() - + # Add static frame. To highlight we're done - path_renderings += [put_text_on_image( - img=image_np, text=f"Done.\nTime:{time_for_init + final_time -start_time:.3f}s. ")]*30 + path_renderings += [ + put_text_on_image( + img=image_np, + text=f"Done.\nTime:{time_for_init + final_time - start_time:.3f}s. ", + ) + ] * 30 # Final rendering at the end. for _ in range(len(path_cameras)): with torch.no_grad(): @@ -159,37 +202,61 @@ def run_training_pipeline(scene_dir, image = render_pkg["render"] image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1) image_np = (image_np * 255).astype(np.uint8) - path_renderings.append(put_text_on_image(img=image_np, - text=f"Final result.\nTime:{time_for_init + final_time -start_time:.3f}s. ")) + path_renderings.append( + put_text_on_image( + img=image_np, + text=f"Final result.\nTime:{time_for_init + final_time - start_time:.3f}s. ", + ) + ) trainer.save_model() - final_video_path = os.path.join(STATIC_FILE_SERVING_FOLDER, f"{scene_name}_final.mp4") - save_numpy_frames_as_mp4(frames=path_renderings, output_path=final_video_path, fps=30, center_crop=0.85) - MODEL_PATH = cfg.gs.dataset.model_path - ply_path = os.path.join(cfg.gs.dataset.model_path, f"point_cloud/iteration_{trainer.gs_step}/point_cloud.ply") - shutil.copy(ply_path, os.path.join(STATIC_FILE_SERVING_FOLDER, "point_cloud_final.ply")) + final_video_path = os.path.join( + STATIC_FILE_SERVING_FOLDER, f"{scene_name}_final.mp4" + ) + save_numpy_frames_as_mp4( + frames=path_renderings, output_path=final_video_path, fps=30, center_crop=0.85 + ) + MODEL_PATH = cfg.gs.dataset.output_path + original_ply_path = os.path.join( # Renamed for clarity + cfg.gs.dataset.output_path, + f"point_cloud/iteration_{trainer.gs_step}/point_cloud.ply", + ) + # This is the path to the copied file in an allowed directory + copied_ply_path_for_serving = os.path.join( + STATIC_FILE_SERVING_FOLDER, "point_cloud_final.ply" + ) + shutil.copy(original_ply_path, copied_ply_path_for_serving) + + return ( + final_video_path, + copied_ply_path_for_serving, + ) # Return the path to the copied .ply file - return final_video_path, ply_path # Gradio Interface def gradio_interface(input_path, num_ref_views, num_corrs, num_steps): - images, scene_dir = run_full_pipeline(input_path, num_ref_views, num_corrs, max_size=1024) - shutil.copytree(scene_dir, STATIC_FILE_SERVING_FOLDER+'/scene_colmaped', dirs_exist_ok=True) - (final_video_path, ply_path), log_output = capture_logs(run_training_pipeline, - scene_dir, - num_ref_views, - num_corrs, - num_steps) + images, scene_dir = run_full_pipeline( + input_path, num_ref_views, num_corrs, max_size=1024 + ) + shutil.copytree( + scene_dir, STATIC_FILE_SERVING_FOLDER + "/scene_colmaped", dirs_exist_ok=True + ) + (final_video_path, ply_path), log_output = capture_logs( + run_training_pipeline, scene_dir, num_ref_views, num_corrs, num_steps + ) images_rgb = [img[:, :, ::-1] for img in images] return images_rgb, final_video_path, scene_dir, ply_path, log_output + # Dummy Render Functions def render_all_views(scene_dir): viewpoint_cams = trainer.GS.scene.getTrainCameras() - path_cameras = generate_fully_smooth_cameras_with_tsp(existing_cameras=viewpoint_cams, - 
n_selected=8, - n_points_per_segment=60, - closed=False) + path_cameras = generate_fully_smooth_cameras_with_tsp( + existing_cameras=viewpoint_cams, + n_selected=8, + n_points_per_segment=60, + closed=False, + ) path_cameras = path_cameras + path_cameras[::-1] path_renderings = [] @@ -200,19 +267,21 @@ def render_all_views(scene_dir): image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1) image_np = (image_np * 255).astype(np.uint8) path_renderings.append(image_np) - save_numpy_frames_as_mp4(frames=path_renderings, - output_path=os.path.join(STATIC_FILE_SERVING_FOLDER, "render_all_views.mp4"), - fps=30, - center_crop=0.85) - + save_numpy_frames_as_mp4( + frames=path_renderings, + output_path=os.path.join(STATIC_FILE_SERVING_FOLDER, "render_all_views.mp4"), + fps=30, + center_crop=0.85, + ) + return os.path.join(STATIC_FILE_SERVING_FOLDER, "render_all_views.mp4") + def render_circular_path(scene_dir): viewpoint_cams = trainer.GS.scene.getTrainCameras() - path_cameras = generate_circular_camera_path(existing_cameras=viewpoint_cams, - N=240, - radius_scale=0.65, - d=0) + path_cameras = generate_circular_camera_path( + existing_cameras=viewpoint_cams, N=240, radius_scale=0.65, d=0 + ) path_renderings = [] with torch.no_grad(): @@ -222,78 +291,63 @@ def render_circular_path(scene_dir): image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1) image_np = (image_np * 255).astype(np.uint8) path_renderings.append(image_np) - save_numpy_frames_as_mp4(frames=path_renderings, - output_path=os.path.join(STATIC_FILE_SERVING_FOLDER, "render_circular_path.mp4"), - fps=30, - center_crop=0.85) - + save_numpy_frames_as_mp4( + frames=path_renderings, + output_path=os.path.join( + STATIC_FILE_SERVING_FOLDER, "render_circular_path.mp4" + ), + fps=30, + center_crop=0.85, + ) + return os.path.join(STATIC_FILE_SERVING_FOLDER, "render_circular_path.mp4") + # Download Functions def download_cameras(): path = os.path.join(MODEL_PATH, "cameras.json") return f"[📥 Download Cameras.json](file={path})" + def download_model(): path = os.path.join(STATIC_FILE_SERVING_FOLDER, "point_cloud_final.ply") return f"[📥 Download Pretrained Model (.ply)](file={path})" + # Full pipeline helpers def run_full_pipeline(input_path, num_ref_views, num_corrs, max_size=1024): tmpdirname = tempfile.mkdtemp() scene_dir = os.path.join(tmpdirname, "scene") os.makedirs(scene_dir, exist_ok=True) - selected_frames = process_input(input_path, num_ref_views, scene_dir, max_size) + selected_frames = process_input_for_colmap( + input_path, num_ref_views, scene_dir, max_size + ) run_colmap_on_scene(scene_dir) return selected_frames, scene_dir -# Preprocess Input -def process_input(input_path, num_ref_views, output_dir, max_size=1024): - if isinstance(input_path, (str, os.PathLike)): - if os.path.isdir(input_path): - frames = [] - for img_file in sorted(os.listdir(input_path)): - if img_file.lower().endswith(('jpg', 'jpeg', 'png')): - img = Image.open(os.path.join(output_dir, img_file)).convert('RGB') - img.thumbnail((1024, 1024)) - frames.append(np.array(img)) - else: - frames = read_video_frames(video_input=input_path, max_size=max_size) - else: - frames = read_video_frames(video_input=input_path, max_size=max_size) - - frames_scores = preprocess_frames(frames) - selected_frames_indices = select_optimal_frames(scores=frames_scores, - k=min(num_ref_views, len(frames))) - selected_frames = [frames[frame_idx] for frame_idx in selected_frames_indices] - - save_frames_to_scene_dir(frames=selected_frames, 
scene_dir=output_dir) - return selected_frames - -def preprocess_input(input_path, num_ref_views, max_size=1024): - tmpdirname = tempfile.mkdtemp() - scene_dir = os.path.join(tmpdirname, "scene") - os.makedirs(scene_dir, exist_ok=True) - selected_frames = process_input(input_path, num_ref_views, scene_dir, max_size) - run_colmap_on_scene(scene_dir) - return selected_frames, scene_dir def start_training(scene_dir, num_ref_views, num_corrs, num_steps): - return capture_logs(run_training_pipeline, scene_dir, num_ref_views, num_corrs, num_steps) - + return capture_logs( + run_training_pipeline, scene_dir, num_ref_views, num_corrs, num_steps + ) + # Gradio App with gr.Blocks() as demo: with gr.Row(): with gr.Column(scale=6): - gr.Markdown(""" + gr.Markdown( + """ ## 📄 EDGS: Eliminating Densification for Efficient Convergence of 3DGS 🔗 Project Page - """, elem_id="header") + """, + elem_id="header", + ) - gr.Markdown(""" + gr.Markdown( + """ ### 🛠️ How to Use This Demo 1. Upload a **front-facing video** or **a folder of images** of a **static** scene. @@ -306,37 +360,52 @@ def start_training(scene_dir, num_ref_views, num_corrs, num_steps): ✅ Best for scenes with small camera motion. ❗ For full 360° or large-scale scenes, we recommend the Colab version (see project page). - """, elem_id="quickstart") - + """, + elem_id="quickstart", + ) scene_dir_state = gr.State() ply_model_state = gr.State() with gr.Row(): with gr.Column(scale=2): - input_file = gr.File(label="Upload Video or Images", - file_types=[".mp4", ".avi", ".mov", ".png", ".jpg", ".jpeg"], - file_count="multiple") + input_file = gr.File( + label="Upload Video or Images", + file_types=[".mp4", ".avi", ".mov", ".png", ".jpg", ".jpeg"], + file_count="multiple", + ) gr.Examples( - examples = [ + examples=[ [["assets/examples/video_bakery.mp4"]], [["assets/examples/video_flowers.mp4"]], [["assets/examples/video_fruits.mp4"]], [["assets/examples/video_plant.mp4"]], [["assets/examples/video_salad.mp4"]], [["assets/examples/video_tram.mp4"]], - [["assets/examples/video_tulips.mp4"]] - ], + [["assets/examples/video_tulips.mp4"]], + ], inputs=[input_file], label="🎞️ ALternatively, try an Example Video", - examples_per_page=4 + examples_per_page=4, + ) + ref_slider = gr.Slider( + 4, 32, value=16, step=1, label="Number of Reference Views" + ) + corr_slider = gr.Slider( + 5000, + 30000, + value=20000, + step=1000, + label="Correspondences per Reference View", + ) + fit_steps_slider = gr.Slider( + 100, 5000, value=400, step=100, label="Number of optimization steps" ) - ref_slider = gr.Slider(4, 32, value=16, step=1, label="Number of Reference Views") - corr_slider = gr.Slider(5000, 30000, value=20000, step=1000, label="Correspondences per Reference View") - fit_steps_slider = gr.Slider(100, 5000, value=400, step=100, label="Number of optimization steps") preprocess_button = gr.Button("📸 Preprocess Input") start_button = gr.Button("🚀 Start Reconstruction", interactive=False) - gallery = gr.Gallery(label="Selected Reference Views", columns=4, height=300) + gallery = gr.Gallery( + label="Selected Reference Views", columns=4, height=300 + ) with gr.Column(scale=3): gr.Markdown("### 🏋️ Training Visualization") @@ -351,43 +420,104 @@ def start_training(scene_dir, num_ref_views, num_corrs, num_steps): gr.Markdown("### 📦 Output Files") with gr.Row(height=50): with gr.Column(): - #gr.Markdown(value=f"[📥 Download .ply](file/point_cloud_final.ply)") + # gr.Markdown(value=f"[📥 Download .ply](file/point_cloud_final.ply)") download_cameras_button = gr.Button("📥 
Download Cameras.json") download_cameras_file = gr.File(label="📄 Cameras.json") with gr.Column(): - download_model_button = gr.Button("📥 Download Pretrained Model (.ply)") + download_model_button = gr.Button( + "📥 Download Pretrained Model (.ply)" + ) download_model_file = gr.File(label="📄 Pretrained Model (.ply)") log_output_box = gr.Textbox(label="🖥️ Log", lines=10, interactive=False) - def on_preprocess_click(input_file, num_ref_views): - images, scene_dir = preprocess_input(input_file, num_ref_views) - return gr.update(value=[x[...,::-1] for x in images]), scene_dir, gr.update(interactive=True) + def on_preprocess_click(input_file_obj, num_ref_views_val): + """ + Handles the preprocess button click. + Calls the main preprocessing orchestrator and updates the UI. + """ + if input_file_obj is None: + # Handles case where no file is uploaded if input_file component is not required + # or if the user clears the selection. + # For gr.Examples, input_file_obj will be a list containing a list of paths. + # For direct upload with file_count="multiple", it's a list of file objects. + # For direct upload with file_count="single", it's a single file object. + # orchestrate_video_to_colmap_scene should be robust to these. + gr.Warning("Please upload a file or select an example.") + return None, None, gr.update(interactive=False) + + selected_bgr_frames, scene_dir = orchestrate_video_to_colmap_scene( + input_path=input_file_obj, # Pass the raw Gradio file object(s) + num_ref_views=num_ref_views_val, + max_size=1024, + base_work_dir="./gradio_processed_scenes", # Or configure as needed + ) + + if not scene_dir: # Indicates preprocessing failed + gr.Error("Preprocessing failed. Please check the logs or input file.") + return ( + None, + None, + gr.update(interactive=False), + ) # Keep gallery empty, scene_dir None, button disabled + + # Convert BGR numpy arrays to RGB for Gradio gallery. + # gr.Gallery can display a list of NumPy arrays (H, W, C) or PIL Images. + # Assuming selected_bgr_frames contains BGR NumPy arrays. 
+ gallery_display_images = [] + if selected_bgr_frames: + gallery_display_images = [ + frame[..., ::-1] + for frame in selected_bgr_frames + if isinstance(frame, np.ndarray) + ] + + return ( + gr.update(value=gallery_display_images), + scene_dir, # Update the scene_dir_state + gr.update(interactive=True), # Enable the 'Start Reconstruction' button + ) def on_start_click(scene_dir, num_ref_views, num_corrs, num_steps): - (video_path, ply_path), logs = start_training(scene_dir, num_ref_views, num_corrs, num_steps) + (video_path, ply_path), logs = start_training( + scene_dir, num_ref_views, num_corrs, num_steps + ) return video_path, ply_path, logs preprocess_button.click( fn=on_preprocess_click, inputs=[input_file, ref_slider], - outputs=[gallery, scene_dir_state, start_button] + outputs=[gallery, scene_dir_state, start_button], ) start_button.click( fn=on_start_click, inputs=[scene_dir_state, ref_slider, corr_slider, fit_steps_slider], - outputs=[video_output, model3d_viewer, log_output_box] + outputs=[video_output, model3d_viewer, log_output_box], ) - render_all_views_button.click(fn=render_all_views, inputs=[scene_dir_state], outputs=[rendered_video_output]) - render_circular_path_button.click(fn=render_circular_path, inputs=[scene_dir_state], outputs=[rendered_video_output]) - - download_cameras_button.click(fn=lambda: os.path.join(MODEL_PATH, "cameras.json"), inputs=[], outputs=[download_cameras_file]) - download_model_button.click(fn=lambda: os.path.join(STATIC_FILE_SERVING_FOLDER, "point_cloud_final.ply"), inputs=[], outputs=[download_model_file]) + render_all_views_button.click( + fn=render_all_views, inputs=[scene_dir_state], outputs=[rendered_video_output] + ) + render_circular_path_button.click( + fn=render_circular_path, + inputs=[scene_dir_state], + outputs=[rendered_video_output], + ) + download_cameras_button.click( + fn=lambda: os.path.join(MODEL_PATH, "cameras.json"), + inputs=[], + outputs=[download_cameras_file], + ) + download_model_button.click( + fn=lambda: os.path.join(STATIC_FILE_SERVING_FOLDER, "point_cloud_final.ply"), + inputs=[], + outputs=[download_model_file], + ) - gr.Markdown(""" + gr.Markdown( + """ --- ### 📖 Detailed Overview @@ -413,12 +543,22 @@ def on_start_click(scene_dir, num_ref_views, num_corrs, num_steps): --- Preloaded models coming soon. (TODO) - """, elem_id="details") + """, + elem_id="details", + ) if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Launch Gradio demo for EDGS preprocessing and 3D viewing.") - parser.add_argument("--port", type=int, default=7860, help="Port to launch the Gradio app on.") - parser.add_argument("--no_share", action='store_true', help="Disable Gradio sharing and assume local access (default: share=True)") + parser = argparse.ArgumentParser( + description="Launch Gradio demo for EDGS preprocessing and 3D viewing." + ) + parser.add_argument( + "--port", type=int, default=7860, help="Port to launch the Gradio app on." 
+ ) + parser.add_argument( + "--no_share", + action="store_true", + help="Disable Gradio sharing and assume local access (default: share=True)", + ) args = parser.parse_args() demo.launch(server_name="0.0.0.0", server_port=args.port, share=not args.no_share) diff --git a/install.sh b/script/install.sh similarity index 100% rename from install.sh rename to script/install.sh diff --git a/metrics.py b/script/metrics.py similarity index 96% rename from metrics.py rename to script/metrics.py index 511d106..4c3f86c 100644 --- a/metrics.py +++ b/script/metrics.py @@ -107,6 +107,12 @@ def evaluate(model_paths): if __name__ == "__main__": device = torch.device("cuda:0") torch.cuda.set_device(device) + + # Set torch hub cache directory to model/ folder to avoid re-downloading + model_dir = Path("model") + model_dir.mkdir(exist_ok=True) + torch.hub.set_dir(str(model_dir)) + lpips_fn = lpips.LPIPS(net='vgg').to(device) # Set up command line argument parser parser = ArgumentParser(description="Training script parameters") diff --git a/script/train.py b/script/train.py new file mode 100644 index 0000000..0ba97a4 --- /dev/null +++ b/script/train.py @@ -0,0 +1,73 @@ +import os +import sys +from argparse import Namespace + +import hydra +import omegaconf +import wandb + +# Add the project root directory to sys.path +# so that modules from 'source' can be imported. +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if project_root not in sys.path: + sys.path.insert(0, project_root) + +from source.trainer import EDGSTrainer +from source.utils_aux import set_seed + + +@hydra.main(config_path="../configs", config_name="train", version_base="1.2") +def main(cfg: omegaconf.DictConfig): + _ = wandb.init( + entity=cfg.wandb.entity, + project=cfg.wandb.project, + config=omegaconf.OmegaConf.to_container( + cfg, resolve=True, throw_on_missing=True + ), + tags=[cfg.wandb.tag], + name=cfg.wandb.name, + mode=cfg.wandb.mode, + ) + omegaconf.OmegaConf.resolve(cfg) + set_seed(cfg.seed) + + # Init output folder + print("Output folder: {}".format(cfg.gs.dataset.output_path)) + os.makedirs(cfg.gs.dataset.output_path, exist_ok=True) + with open(os.path.join(cfg.gs.dataset.output_path, "cfg_args"), "w") as cfg_log_f: + params = { + "sh_degree": 3, + "source_path": cfg.gs.dataset.source_path, + "model_path": cfg.gs.dataset.output_path, + "images": cfg.gs.dataset.images, + "depths": "", + "resolution": -1, + "_white_background": cfg.gs.dataset.white_background, + "train_test_exp": False, + "data_device": cfg.gs.dataset.data_device, + "eval": False, + "convert_SHs_python": False, + "compute_cov3D_python": False, + "debug": False, + "antialiasing": False, + } + cfg_log_f.write(str(Namespace(**params))) + + # Init both agents + gs = hydra.utils.instantiate(cfg.gs) + + # Init trainer and launch training + trainer = EDGSTrainer(GS=gs, training_config=cfg.gs.opt, device=cfg.device) + + trainer.load_checkpoints(cfg.load) + trainer.timer.start() + trainer.init_with_corr(cfg.init_wC) + trainer.train(cfg.train) + + # All done + wandb.finish() + print("\nTraining complete.") + + +if __name__ == "__main__": + main() diff --git a/source/corr_init.py b/source/corr_init.py index 09c94a2..644a832 100644 --- a/source/corr_init.py +++ b/source/corr_init.py @@ -15,6 +15,7 @@ import torch.nn.functional as F from romatch import roma_outdoor, roma_indoor +from .model_utils import load_roma_model_with_custom_cache from utils.sh_utils import RGB2SH from romatch.utils import get_tuple_transform_ops @@ -532,9 +533,9 @@ def 
init_gaussians_with_corr(gaussians, scene, cfg, device, verbose = False, rom """ if roma_model is None: if cfg.roma_model == "indoors": - roma_model = roma_indoor(device=device) + roma_model = load_roma_model_with_custom_cache(model_type="indoor", device=device) else: - roma_model = roma_outdoor(device=device) + roma_model = load_roma_model_with_custom_cache(model_type="outdoor", device=device) roma_model.upsample_preds = False roma_model.symmetric = False M = cfg.matches_per_ref @@ -753,9 +754,9 @@ def init_gaussians_with_corr_fast(gaussians, scene, cfg, device, verbose=False, if roma_model is None: if cfg.roma_model == "indoors": - roma_model = roma_indoor(device=device) + roma_model = load_roma_model_with_custom_cache(model_type="indoor", device=device) else: - roma_model = roma_outdoor(device=device) + roma_model = load_roma_model_with_custom_cache(model_type="outdoor", device=device) roma_model.upsample_preds = False roma_model.symmetric = False diff --git a/source/model_utils.py b/source/model_utils.py new file mode 100644 index 0000000..ff47e80 --- /dev/null +++ b/source/model_utils.py @@ -0,0 +1,63 @@ +import os +import sys +import torch + +PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +MODEL_DIR = os.path.join(PROJECT_ROOT, "model") +os.makedirs(MODEL_DIR, exist_ok=True) + +# Add RoMa path +sys.path.append(os.path.join(PROJECT_ROOT, "submodules", "RoMa")) + +def load_roma_model_with_custom_cache(model_type="indoor", device="cuda"): + """ + Load RoMa model with custom model cache directory instead of torch hub cache + + Args: + model_type: "indoor" or "outdoor" + device: Device to load the model on + + Returns: + RoMa model instance + """ + # Import RoMa functions + from romatch import roma_indoor, roma_outdoor + + # Define URLs for models + weight_urls = { + "romatch": { + "outdoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_outdoor.pth", + "indoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_indoor.pth", + }, + "dinov2": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", + } + + def download_model_to_custom_dir(url, filename): + """Download model to custom model directory instead of torch hub cache""" + model_path = os.path.join(MODEL_DIR, filename) + if os.path.exists(model_path): + print(f"Loading cached model from {model_path}") + return torch.load(model_path, map_location='cpu') + else: + print(f"Downloading model to {model_path}") + state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu') + torch.save(state_dict, model_path) + return state_dict + + # Download models to custom directory + roma_weights = download_model_to_custom_dir( + weight_urls["romatch"][model_type], + f"roma_{model_type}.pth" + ) + dinov2_weights = download_model_to_custom_dir( + weight_urls["dinov2"], + "dinov2_vitl14_pretrain.pth" + ) + + # Load model with custom weights + if model_type == "indoor": + model = roma_indoor(device=device, weights=roma_weights, dinov2_weights=dinov2_weights) + else: + model = roma_outdoor(device=device, weights=roma_weights, dinov2_weights=dinov2_weights) + + return model \ No newline at end of file diff --git a/source/trainer.py b/source/trainer.py index 7fbd28a..f366dd0 100644 --- a/source/trainer.py +++ b/source/trainer.py @@ -5,6 +5,8 @@ from source.networks import Warper3DGS import wandb import sys +import os +from pathlib import Path sys.path.append('./submodules/gaussian-splatting/') import lpips @@ -50,6 +52,12 @@ def __init__(self, # Logs in 
the format {step:{"loss1":loss1_value, "loss2":loss2_value}} self.logs_losses = {} + + # Set torch hub cache directory to model/ folder to avoid re-downloading + model_dir = Path("model") + model_dir.mkdir(exist_ok=True) + torch.hub.set_dir(str(model_dir)) + self.lpips = lpips.LPIPS(net='vgg').to(device) self.device = device self.timer = Timer() diff --git a/source/utils_preprocess.py b/source/utils_preprocess.py index 7c6dab3..94a5195 100644 --- a/source/utils_preprocess.py +++ b/source/utils_preprocess.py @@ -1,41 +1,42 @@ # This file contains function for video or image collection preprocessing. # For video we do the preprocessing and select k sharpest frames. -# Afterwards scene is constructed +# Afterwards scene is constructed +import os +import time + import cv2 import numpy as np -from tqdm import tqdm import pycolmap -import os -import time -import tempfile -from moviepy import VideoFileClip from matplotlib import pyplot as plt +from moviepy import VideoFileClip from PIL import Image -import cv2 from tqdm import tqdm -WORKDIR = "../outputs/" - def get_rotation_moviepy(video_path): clip = VideoFileClip(video_path) rotation = 0 try: - displaymatrix = clip.reader.infos['inputs'][0]['streams'][2]['metadata'].get('displaymatrix', '') - if 'rotation of' in displaymatrix: - angle = float(displaymatrix.strip().split('rotation of')[-1].split('degrees')[0]) + displaymatrix = clip.reader.infos["inputs"][0]["streams"][2]["metadata"].get( + "displaymatrix", "" + ) + if "rotation of" in displaymatrix: + angle = float( + displaymatrix.strip().split("rotation of")[-1].split("degrees")[0] + ) rotation = int(angle) % 360 - + except Exception as e: print(f"No displaymatrix rotation found: {e}") clip.reader.close() - #if clip.audio: + # if clip.audio: # clip.audio.reader.close_proc() return rotation + def resize_max_side(frame, max_size): h, w = frame.shape[:2] scale = max_size / max(h, w) @@ -43,6 +44,7 @@ def resize_max_side(frame, max_size): frame = cv2.resize(frame, (int(w * scale), int(h * scale))) return frame + def read_video_frames(video_input, k=1, max_size=1024): """ Extracts every k-th frame from a video or list of images, resizes to max size, and returns frames as list. @@ -58,7 +60,9 @@ def read_video_frames(video_input, k=1, max_size=1024): # Handle list of image files (not single video in a list) if isinstance(video_input, list): # If it's a single video in a list, treat it as video - if len(video_input) == 1 and video_input[0].name.endswith(('.mp4', '.avi', '.mov')): + if len(video_input) == 1 and video_input[0].name.endswith( + (".mp4", ".avi", ".mov") + ): video_input = video_input[0] # unwrap single video file else: # Treat as list of images @@ -66,18 +70,19 @@ def read_video_frames(video_input, k=1, max_size=1024): for img_file in video_input: img = Image.open(img_file.name).convert("RGB") img.thumbnail((max_size, max_size)) - frames.append(np.array(img)[...,::-1]) + frames.append(np.array(img)[..., ::-1]) return frames # Handle file-like or path - if hasattr(video_input, 'name'): + if hasattr(video_input, "name"): video_path = video_input.name elif isinstance(video_input, (str, os.PathLike)): video_path = str(video_input) else: - raise ValueError("Unsupported video input type. Must be a filepath, file-like object, or list of images.") + raise ValueError( + "Unsupported video input type. Must be a filepath, file-like object, or list of images." 
+        )
 
     cap = cv2.VideoCapture(video_path)
     if not cap.isOpened():
         raise ValueError(f"Error: Could not open video {video_path}.")
@@ -97,20 +102,21 @@ def read_video_frames(video_input, k=1, max_size=1024):
         scale = max(h, w) / max_size
         if scale > 1:
             frame = cv2.resize(frame, (int(w / scale), int(h / scale)))
-        frames.append(frame[...,[2,1,0]])
+        frames.append(frame[..., [2, 1, 0]])
         pbar.update(1)
         frame_count += 1
 
     cap.release()
     return frames
 
+
 def resize_max_side(frame, max_size):
     """
     Resizes the frame so that its largest side equals max_size, maintaining aspect ratio.
     """
     height, width = frame.shape[:2]
     max_dim = max(height, width)
-    
+
     if max_dim <= max_size:
         return frame  # No need to resize
 
@@ -118,41 +124,47 @@ def resize_max_side(frame, max_size):
     new_width = int(width * scale)
     new_height = int(height * scale)
 
-    resized_frame = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_AREA)
+    resized_frame = cv2.resize(
+        frame, (new_width, new_height), interpolation=cv2.INTER_AREA
+    )
     return resized_frame
 
-
 def variance_of_laplacian(image):
-    # compute the Laplacian of the image and then return the focus
-    # measure, which is simply the variance of the Laplacian
-    return cv2.Laplacian(image, cv2.CV_64F).var()
-
-def process_all_frames(IMG_FOLDER = '/scratch/datasets/hq_data/night2_all_frames',
-                       to_visualize=False,
-                       save_images=True):
+    # compute the Laplacian of the image and then return the focus
+    # measure, which is simply the variance of the Laplacian
+    return cv2.Laplacian(image, cv2.CV_64F).var()
+
+
+def process_all_frames(
+    IMG_FOLDER="/scratch/datasets/hq_data/night2_all_frames",
+    to_visualize=False,
+    save_images=True,
+):
     dict_scores = {}
-    for idx, img_name in tqdm(enumerate(sorted([x for x in os.listdir(IMG_FOLDER) if '.png' in x]))):
-
-        img = cv2.imread(os.path.join(IMG_FOLDER, img_name))#[250:, 100:]
+    for idx, img_name in tqdm(
+        enumerate(sorted([x for x in os.listdir(IMG_FOLDER) if ".png" in x]))
+    ):
+        img = cv2.imread(os.path.join(IMG_FOLDER, img_name))  # [250:, 100:]
         gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
-        fm = variance_of_laplacian(gray) + \
-            variance_of_laplacian(cv2.resize(gray, (0,0), fx=0.75, fy=0.75)) + \
-            variance_of_laplacian(cv2.resize(gray, (0,0), fx=0.5, fy=0.5)) + \
-            variance_of_laplacian(cv2.resize(gray, (0,0), fx=0.25, fy=0.25))
+        fm = (
+            variance_of_laplacian(gray)
+            + variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.75, fy=0.75))
+            + variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.5, fy=0.5))
+            + variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.25, fy=0.25))
+        )
        if to_visualize:
            plt.figure()
            plt.title(f"Laplacian score: {fm:.2f}")
-            plt.imshow(img[..., [2,1,0]])
+            plt.imshow(img[..., [2, 1, 0]])
            plt.show()
-        dict_scores[idx] = {"idx" : idx,
-                            "img_name" : img_name,
-                            "score" : fm}
+        dict_scores[idx] = {"idx": idx, "img_name": img_name, "score": fm}
         if save_images:
             dict_scores[idx]["img"] = img
-        
+
     return dict_scores
 
+
 def select_optimal_frames(scores, k):
     """
     Selects a minimal subset of frames while ensuring no gaps exceed k.
@@ -165,12 +177,14 @@ def select_optimal_frames(scores, k):
         list of int: Indices of selected frames.
""" n = len(scores) - selected = [0, n-1] + selected = [0, n - 1] i = 0 # Start at the first frame while i < n: # Find the best frame to select within the next k frames - best_idx = max(range(i, min(i + k + 1, n)), key=lambda x: scores[x], default=None) + best_idx = max( + range(i, min(i + k + 1, n)), key=lambda x: scores[x], default=None + ) if best_idx is None: break # No more frames left @@ -187,6 +201,7 @@ def variance_of_laplacian(image): """ return cv2.Laplacian(image, cv2.CV_64F).var() + def preprocess_frames(frames, verbose=False): """ Compute sharpness scores for a list of frames using multi-scale Laplacian variance. @@ -204,12 +219,12 @@ def preprocess_frames(frames, verbose=False): gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) fm = ( - variance_of_laplacian(gray) + - variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.75, fy=0.75)) + - variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.5, fy=0.5)) + - variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.25, fy=0.25)) + variance_of_laplacian(gray) + + variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.75, fy=0.75)) + + variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.5, fy=0.5)) + + variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.25, fy=0.25)) ) - + if verbose: print(f"Frame {idx}: Sharpness Score = {fm:.2f}") @@ -217,6 +232,7 @@ def preprocess_frames(frames, verbose=False): return scores + def select_optimal_frames(scores, k): """ Selects k frames by splitting into k segments and picking the sharpest frame from each. @@ -226,7 +242,7 @@ def select_optimal_frames(scores, k): k (int): Number of frames to select. Returns: - list of int: Indices of selected frames. + list of int: Indices of selected frames. """ n = len(scores) selected_indices = [] @@ -236,15 +252,16 @@ def select_optimal_frames(scores, k): start = i * segment_size end = (i + 1) * segment_size if i < k - 1 else n # Last chunk may be larger segment_scores = scores[start:end] - + if len(segment_scores) == 0: continue # Safety check if some segment is empty - + best_in_segment = start + np.argmax(segment_scores) selected_indices.append(best_in_segment) return sorted(selected_indices) + def save_frames_to_scene_dir(frames, scene_dir): """ Saves a list of frames into the target scene directory under 'images/' subfolder. @@ -257,7 +274,9 @@ def save_frames_to_scene_dir(frames, scene_dir): os.makedirs(images_dir, exist_ok=True) for idx, frame in enumerate(frames): - filename = os.path.join(images_dir, f"{idx:08d}.png") # 00000000.png, 00000001.png, etc. + filename = os.path.join( + images_dir, f"{idx:08d}.png" + ) # 00000000.png, 00000001.png, etc. cv2.imwrite(filename, frame) print(f"Saved {len(frames)} frames to {images_dir}") @@ -269,7 +288,7 @@ def run_colmap_on_scene(scene_dir): Args: scene_dir (str): Path to scene directory containing 'images' folder. 
- + TODO: if the function hasn't managed to match all the frames either increase image size, increase number of features or just remove those frames from the folder scene_dir/images """ @@ -280,7 +299,7 @@ def run_colmap_on_scene(scene_dir): database_path = os.path.join(scene_dir, "database.db") sparse_path = os.path.join(scene_dir, "sparse") image_dir = os.path.join(scene_dir, "images") - + # Make sure output directories exist os.makedirs(sparse_path, exist_ok=True) @@ -291,7 +310,7 @@ def run_colmap_on_scene(scene_dir): sift_options={ "max_num_features": 512 * 2, "max_image_size": 512 * 1, - } + }, ) print(f"Finished feature extraction in {(time.time() - start_time):.2f}s.") @@ -325,10 +344,160 @@ def run_colmap_on_scene(scene_dir): reconstruction = pycolmap.Reconstruction(recon_path) for cam in reconstruction.cameras.values(): - cam.model = 'SIMPLE_PINHOLE' + cam.model = "SIMPLE_PINHOLE" cam.params = cam.params[:3] # Keep only [f, cx, cy] reconstruction.write(recon_path) print(f"Total pipeline time: {(time.time() - start_time):.2f}s.") + +def process_input_for_colmap(input_path, num_ref_views, output_dir, max_size=1024): + """ + Helper function to read frames from video or image folder, select optimal ones, + and save them to the output_dir/images. + This is based on process_input from gradio_demo.py. + Renamed to avoid potential confusion if 'process_input' is too generic. + """ + frames_to_save_in_scene_dir = [] + if isinstance(input_path, (str, os.PathLike)): # If input_path is a path string + if os.path.isdir(input_path): # If it's a directory of images + print(f"Processing image directory: {input_path}") + raw_frames = [] + image_files = sorted( + [ + f + for f in os.listdir(input_path) + if f.lower().endswith(("jpg", "jpeg", "png")) + ] + ) + for img_file in image_files: + img = Image.open(os.path.join(input_path, img_file)).convert("RGB") + # Resize if necessary, similar to video frames + width, height = img.size + if max(width, height) > max_size: + scale = max_size / max(width, height) + new_width = int(width * scale) + new_height = int(height * scale) + img = img.resize((new_width, new_height), Image.LANCZOS) + raw_frames.append(np.array(img)) + else: # If it's a single video file path + print(f"Processing video file: {input_path}") + raw_frames = read_video_frames(video_input=input_path, max_size=max_size) + elif hasattr( + input_path, "name" + ): # If input_path is a file-like object (e.g., from Gradio upload) + print(f"Processing uploaded video file: {input_path.name}") + raw_frames = read_video_frames(video_input=input_path.name, max_size=max_size) + else: + raise ValueError(f"Unsupported input_path type: {type(input_path)}") + + if not raw_frames: + print("No frames extracted or read.") + return [] + + frames_scores = preprocess_frames( + raw_frames + ) # Assuming preprocess_frames takes list of numpy arrays + selected_frames_indices = select_optimal_frames( + scores=frames_scores, k=min(num_ref_views, len(raw_frames)) + ) + frames_to_save_in_scene_dir = [ + raw_frames[frame_idx] for frame_idx in selected_frames_indices + ] + + # The 'output_dir' here is the scene_dir where 'images' subfolder will be created + save_frames_to_scene_dir(frames=frames_to_save_in_scene_dir, scene_dir=output_dir) + return frames_to_save_in_scene_dir # Returns the list of selected frame data (numpy arrays) + + +def orchestrate_video_to_colmap_scene( + input_path, + num_ref_views, + max_size=1024, + base_work_dir="../outputs/processed_scenes", +): + """ + Orchestrates the full video/image folder 
preprocessing pipeline: + 1. Creates a temporary scene directory. + 2. Reads frames, selects optimal ones, saves them. + 3. Runs COLMAP on the scene. + Args: + input_path (str or file-like): Path string, a Gradio file object, or a list (e.g., from gr.Examples). + num_ref_views (int): Number of reference views to select. + max_size (int): Maximum size for width or height after resizing. + base_work_dir (str): Base directory for temporary scene directories. + Returns: + the list of selected frame image data and the path to the COLMAP processed scene directory. + This is based on preprocess_input from gradio_demo.py. + """ + actual_input_path_str = None + input_name_part = "temp_scene" # Default + + if hasattr(input_path, "name") and isinstance( + input_path.name, str + ): # Gradio file object + actual_input_path_str = input_path.name + input_name_part = os.path.splitext(os.path.basename(input_path.name))[0] + elif isinstance(input_path, (str, os.PathLike)): # Direct path string + actual_input_path_str = str(input_path) + input_name_part = os.path.splitext(os.path.basename(input_path))[0] + elif ( + isinstance(input_path, list) and input_path + ): # Handle list: take the first item. + # gr.Examples often wraps the path in another list, e.g., [['path/to/example.mp4']] + # So, we might need to unwrap it. + first_item_candidate = input_path[0] + if ( + isinstance(first_item_candidate, list) and first_item_candidate + ): # Check for nested list + first_item = first_item_candidate[0] + else: + first_item = first_item_candidate + + if hasattr(first_item, "name") and isinstance( + first_item.name, str + ): # Gradio file object in list + actual_input_path_str = first_item.name + input_name_part = os.path.splitext(os.path.basename(first_item.name))[0] + elif isinstance(first_item, (str, os.PathLike)): # Path string in list + actual_input_path_str = str(first_item) + input_name_part = os.path.splitext(os.path.basename(first_item))[0] + else: + print(f"Warning: Unsupported item type in input list: {type(first_item)}") + return [], None + else: + print(f"Error: Unsupported input_path type: {type(input_path)}") + return [], None + + if not actual_input_path_str: + print("Error: Could not determine a valid input file path.") + return [], None + + print(f"Orchestrating COLMAP scene from: {actual_input_path_str}") + + # Using a structured output directory instead of pure tempfile.mkdtemp for easier inspection + # scene_dir_parent = tempfile.mkdtemp() # Original approach + + # Ensure base_work_dir exists + os.makedirs(base_work_dir, exist_ok=True) + # Create a unique subdirectory within base_work_dir + timestamp = time.strftime("%Y%m%d-%H%M%S") + scene_dir = os.path.join(base_work_dir, f"{input_name_part}_{timestamp}") + + os.makedirs(scene_dir, exist_ok=True) + print(f"Created scene directory for COLMAP: {scene_dir}") + + selected_frames_data = process_input_for_colmap( + actual_input_path_str, num_ref_views, scene_dir, max_size + ) + if not selected_frames_data: + print(f"Frame processing failed for {input_path}. 
Aborting COLMAP.") + # Optionally clean up scene_dir if it's truly temporary and processing failed + # shutil.rmtree(scene_dir) + return [], None + + run_colmap_on_scene(scene_dir) # This function should create scene_dir/sparse/0 + + print(f"COLMAP processing complete for {scene_dir}") + return selected_frames_data, scene_dir diff --git a/train.py b/train.py deleted file mode 100644 index 646409a..0000000 --- a/train.py +++ /dev/null @@ -1,63 +0,0 @@ -import os -from source.trainer import EDGSTrainer -from source.utils_aux import set_seed -import omegaconf -import wandb -import hydra -from argparse import Namespace -from omegaconf import OmegaConf - - -@hydra.main(config_path="configs", config_name="train", version_base="1.2") -def main(cfg: omegaconf.DictConfig): - _ = wandb.init(entity=cfg.wandb.entity, - project=cfg.wandb.project, - config=omegaconf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True), - tags=[cfg.wandb.tag], - name = cfg.wandb.name, - mode = cfg.wandb.mode) - omegaconf.OmegaConf.resolve(cfg) - set_seed(cfg.seed) - - # Init output folder - print("Output folder: {}".format(cfg.gs.dataset.model_path)) - os.makedirs(cfg.gs.dataset.model_path, exist_ok=True) - with open(os.path.join(cfg.gs.dataset.model_path, "cfg_args"), 'w') as cfg_log_f: - params = { - "sh_degree": 3, - "source_path": cfg.gs.dataset.source_path, - "model_path": cfg.gs.dataset.model_path, - "images": cfg.gs.dataset.images, - "depths": "", - "resolution": -1, - "_white_background": cfg.gs.dataset.white_background, - "train_test_exp": False, - "data_device": cfg.gs.dataset.data_device, - "eval": False, - "convert_SHs_python": False, - "compute_cov3D_python": False, - "debug": False, - "antialiasing": False - } - cfg_log_f.write(str(Namespace(**params))) - - # Init both agents - gs = hydra.utils.instantiate(cfg.gs) - - # Init trainer and launch training - trainer = EDGSTrainer(GS=gs, - training_config=cfg.gs.opt, - device=cfg.device) - - trainer.load_checkpoints(cfg.load) - trainer.timer.start() - trainer.init_with_corr(cfg.init_wC) - trainer.train(cfg.train) - - # All done - wandb.finish() - print("\nTraining complete.") - -if __name__ == "__main__": - main() -
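
For reviewers who want to sanity-check the new flow end to end, the following is a minimal usage sketch (not part of the patch). It combines `orchestrate_video_to_colmap_scene` from `source/utils_preprocess.py` with the relocated `script/train.py` entry point; the RoMa/DINOv2 weights needed at correspondence initialization are fetched into `model/` by `load_roma_model_with_custom_cache` the first time they are used. The input video path and output directory below are hypothetical, and the hydra override keys are assumed from the `cfg.gs.dataset.source_path` / `cfg.gs.dataset.output_path` fields written to `cfg_args` in `script/train.py`; they have not been checked against the actual config files.

```python
# Reviewer sketch (assumptions noted above): preprocess a video into a COLMAP
# scene, then train on it via the relocated entry point.
import subprocess

from source.utils_preprocess import orchestrate_video_to_colmap_scene

# 1) Select the sharpest frames from a video and run COLMAP on them.
selected_frames, scene_dir = orchestrate_video_to_colmap_scene(
    input_path="data/my_video.mp4",            # hypothetical input video
    num_ref_views=16,
    max_size=1024,
    base_work_dir="outputs/processed_scenes",
)

# 2) Launch training on the processed scene (hydra override keys are assumed).
if scene_dir is not None:
    subprocess.run(
        [
            "python", "script/train.py",
            f"gs.dataset.source_path={scene_dir}",       # assumed override key
            "gs.dataset.output_path=outputs/my_scene",   # assumed override key
        ],
        check=True,
    )
```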