diff --git a/ml/dataset/npz_sequence.py b/ml/dataset/npz_sequence.py
index 5179d5c..d8727c7 100644
--- a/ml/dataset/npz_sequence.py
+++ b/ml/dataset/npz_sequence.py
@@ -225,8 +225,8 @@ class FluidNPZSequenceDataset(Dataset):
     def __init__(
         self,
         npz_dir: str | Path,
+        split: str,
         normalize: bool = False,
-        seq_indices: list[int] | None = None,
        is_training: bool = False,
        augmentation_config: dict | None = None,
        preload: bool = False,
@@ -242,18 +242,15 @@ def __init__(
        self.enable_augmentation = self.augmentation_config.get("enable_augmentation", False)
        self.flip_probability = self.augmentation_config.get("flip_probability", 0.5)
 
-        npz_dir_path = Path(npz_dir)
-        all_seq_paths: list[Path] = sorted(
+        npz_dir_path = Path(npz_dir) / split
+        if not npz_dir_path.exists():
+            raise FileNotFoundError(f"Split directory not found: {npz_dir_path}")
+
+        self.seq_paths: list[Path] = sorted(
            [f for f in npz_dir_path.iterdir() if f.name.startswith("seq_") and f.name.endswith(".npz")]
        )
-        if not all_seq_paths:
-            raise FileNotFoundError(f"No seq_*.npz files found in {npz_dir}")
-
-        # Filter to specified sequences if provided (splits)
-        if seq_indices is not None:
-            self.seq_paths = [all_seq_paths[i] for i in seq_indices]
-        else:
-            self.seq_paths = all_seq_paths
+        if not self.seq_paths:
+            raise FileNotFoundError(f"No seq_*.npz files found in {npz_dir_path}")
 
        self.num_real_sequences = len(self.seq_paths)
        # self.num_fake_sequences = _calculate_fake_count(self.num_real_sequences, fake_empty_pct)
diff --git a/ml/scripts/eval_rollout.py b/ml/scripts/eval_rollout.py
index ea79c49..044ac03 100644
--- a/ml/scripts/eval_rollout.py
+++ b/ml/scripts/eval_rollout.py
@@ -8,7 +8,6 @@
 
 sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
 
-import yaml
 from config.config import PROJECT_ROOT_PATH
 from config.training_config import TrainingConfig, project_config
@@ -20,7 +19,6 @@
    _plot_rollout_degradation,
    _run_single_rollout,
 )
-from utils.data_splits import make_splits, set_seed
 
 ML_ROOT = Path(__file__).resolve().parent.parent
 DATA_DIR = ML_ROOT.parent / "data" / "npz" / "128"
@@ -65,12 +63,14 @@ def load_from_checkpoint(checkpoint_dir: Path, device: str) -> tuple[TrainingCon
 def run_rollout_to_disk(
    model: torch.nn.Module,
    config: TrainingConfig,
-    test_indices: list[int],
    device: str,
    output_name: str,
 ) -> None:
-    all_seq_paths = sorted([p for p in DATA_DIR.iterdir() if p.name.startswith("seq_") and p.name.endswith(".npz")])
-    test_paths = [all_seq_paths[i] for i in test_indices]
+    test_data_dir = DATA_DIR / "test"
+    if not test_data_dir.exists():
+        raise FileNotFoundError(f"Test split directory not found: {test_data_dir}")
+
+    test_paths = sorted([p for p in test_data_dir.iterdir() if p.name.startswith("seq_") and p.name.endswith(".npz")])
 
    norm_scales = None
    if config.normalize:
@@ -144,17 +144,8 @@ def main() -> None:
 
    config, model = load_from_checkpoint(checkpoint_dir, device)
 
-    with open(BASE_CONFIG_PATH) as f:
-        base_cfg = yaml.safe_load(f)
-    split_ratios = tuple(base_cfg["split_ratios"])
-    split_seed = base_cfg["split_seed"]
-
-    set_seed(split_seed)
-    _, _, test_idx = make_splits(DATA_DIR, split_ratios, split_seed)
-    print(f"Test sequences: {len(test_idx)}")
-
    output_name = args.output_name or checkpoint_dir.name
-    run_rollout_to_disk(model, config, test_idx, device, output_name)
+    run_rollout_to_disk(model, config, device, output_name)
 
 
 if __name__ == "__main__":
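Note on the two files above: sequence discovery now follows the same per-split directory pattern everywhere, `<npz_root>/<resolution>/<split>/seq_*.npz`, replacing the old index-list plumbing. A minimal sketch of that shared pattern (the `list_split_sequences` helper is hypothetical, not part of the patch):

```python
# Sketch of the split-directory discovery pattern used in this patch.
from pathlib import Path


def list_split_sequences(npz_root: Path, resolution: int, split: str) -> list[Path]:
    split_dir = npz_root / str(resolution) / split
    if not split_dir.exists():
        # Fail fast so a mis-generated dataset can't silently train on the wrong data.
        raise FileNotFoundError(f"Split directory not found: {split_dir}")
    paths = sorted(p for p in split_dir.iterdir() if p.name.startswith("seq_") and p.name.endswith(".npz"))
    if not paths:
        raise FileNotFoundError(f"No seq_*.npz files found in {split_dir}")
    return paths


# e.g. list_split_sequences(Path("data/npz"), 128, "test")
```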
diff --git a/ml/scripts/train.py b/ml/scripts/train.py
index 719f10b..fa60838 100644
--- a/ml/scripts/train.py
+++ b/ml/scripts/train.py
@@ -15,7 +15,7 @@
 from scripts.variant_manager import VariantManager
 from training.test_evaluation import run_rollout_evaluation, run_test_evaluation
 from training.trainer import Trainer
-from utils.data_splits import make_splits, set_seed
+from utils.seed import set_seed
 
 
 def dict_to_training_config(config_dict: dict) -> TrainingConfig:
@@ -75,9 +75,6 @@ def train_single_variant(
    set_seed(config.split_seed)
 
    npz_dir: Path = config.npz_dir / str(project_config.simulation.grid_resolution)
-    train_idx, val_idx, test_idx = make_splits(npz_dir, config.split_ratios, config.split_seed)
-
-    print(f"Dataset splits: train={len(train_idx)}, val={len(val_idx)}, test={len(test_idx)}")
 
    aug_config_dict = {
        "enable_augmentation": config.augmentation.enable_augmentation,
@@ -88,8 +85,8 @@ def train_single_variant(
 
    train_ds = FluidNPZSequenceDataset(
        npz_dir=npz_dir,
+        split="train",
        normalize=config.normalize,
-        seq_indices=train_idx,
        is_training=True,
        augmentation_config=aug_config_dict,
        preload=config.preload_dataset,
@@ -100,8 +97,8 @@ def train_single_variant(
    val_rollout_steps = config.rollout_step if config.validation_use_rollout_k else 1
    val_ds = FluidNPZSequenceDataset(
        npz_dir=npz_dir,
+        split="val",
        normalize=config.normalize,
-        seq_indices=val_idx,
        is_training=False,
        augmentation_config=None,
        preload=config.preload_dataset,
@@ -243,25 +240,20 @@ def train_single_variant(
        val_loader=val_loader,
        config=config,
        device=config.device,
-        train_indices=train_idx,
-        val_indices=val_idx,
    )
 
    trainer.train()
 
-    if test_idx:
-        run_test_evaluation(
-            model=model,
-            config=config,
-            test_indices=test_idx,
-            device=config.device,
-        )
-        run_rollout_evaluation(
-            model=model,
-            config=config,
-            test_indices=test_idx,
-            device=config.device,
-        )
+    run_test_evaluation(
+        model=model,
+        config=config,
+        device=config.device,
+    )
+    run_rollout_evaluation(
+        model=model,
+        config=config,
+        device=config.device,
+    )
 
    print(f"\nTraining complete: {config.variant.full_model_name}")
    print(f"Best model: {config.checkpoint_dir_variant / 'best_model.pth'}")
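A cheap regression test for the new fail-fast behavior could look like this (hypothetical test, assuming pytest and the dataset constructor introduced above):

```python
# Hypothetical test: a dataset pointed at a directory with no train/ subfolder
# should raise immediately rather than train on nothing.
import pytest

from dataset.npz_sequence import FluidNPZSequenceDataset


def test_missing_split_dir_raises(tmp_path):
    with pytest.raises(FileNotFoundError, match="Split directory not found"):
        FluidNPZSequenceDataset(npz_dir=tmp_path, split="train")
```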
rollout evaluation.") return {} - npz_dir = config.npz_dir / str(project_config.simulation.grid_resolution) - all_seq_paths = sorted([p for p in npz_dir.iterdir() if p.name.startswith("seq_") and p.name.endswith(".npz")]) - test_paths = [all_seq_paths[i] for i in test_indices] + test_paths = sorted([p for p in test_npz_dir.iterdir() if p.name.startswith("seq_") and p.name.endswith(".npz")]) norm_scales = None if config.normalize: diff --git a/ml/training/trainer.py b/ml/training/trainer.py index 9f1a4d5..6b7aeb7 100644 --- a/ml/training/trainer.py +++ b/ml/training/trainer.py @@ -39,8 +39,6 @@ def __init__( val_loader: DataLoader, config: TrainingConfig, device: str, - train_indices: list[int] | None = None, - val_indices: list[int] | None = None, ) -> None: self.model = model self.train_loader = train_loader @@ -48,10 +46,6 @@ def __init__( self.config = config.model_copy() self.device = device - # Store indices for dataloader recreation during scheduled rollout - self.train_indices = train_indices - self.val_indices = val_indices - self.gradient_clip_norm = config.gradient_clip_norm self.gradient_clip_enabled = config.gradient_clip_enabled @@ -391,14 +385,7 @@ def _compute_rollout_loss( return total_loss, averaged_dict, final_pred def _create_dataloader_with_rollout_steps(self, K: int, is_training: bool) -> DataLoader: - if is_training: - if self.train_indices is None: - raise RuntimeError("Cannot recreate dataloader: train_indices not provided to Trainer") - indices = self.train_indices - else: - if self.val_indices is None: - raise RuntimeError("Cannot recreate dataloader: val_indices not provided to Trainer") - indices = self.val_indices + split = "train" if is_training else "val" npz_dir = ( PROJECT_ROOT_PATH @@ -408,8 +395,8 @@ def _create_dataloader_with_rollout_steps(self, K: int, is_training: bool) -> Da dataset = FluidNPZSequenceDataset( npz_dir=npz_dir, + split=split, normalize=self.config.normalize, - seq_indices=indices, is_training=is_training, augmentation_config=self.config.augmentation.model_dump() if is_training else None, preload=self.config.preload_dataset, diff --git a/ml/utils/data_splits.py b/ml/utils/data_splits.py deleted file mode 100644 index a9747f1..0000000 --- a/ml/utils/data_splits.py +++ /dev/null @@ -1,54 +0,0 @@ -import random -from pathlib import Path -from typing import cast - -import numpy as np -import torch - - -def set_seed(seed: int) -> None: - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) - - -def make_splits( - npz_dir: Path, split_ratios: tuple[float, float, float], seed: int -) -> tuple[list[int], list[int], list[int]]: - seq_paths = sorted([p for p in npz_dir.iterdir() if p.name.startswith("seq_") and p.name.endswith(".npz")]) - n_seq = len(seq_paths) - - if n_seq < 1: - raise ValueError(f"Need at least 1 sequence, found {n_seq}") - - indices = list(range(n_seq)) - random.Random(seed).shuffle(indices) - - ratio_sum = sum(split_ratios) - if ratio_sum == 0: - normalized_ratios = (1 / 3, 1 / 3, 1 / 3) - else: - normalized_ratios = cast("tuple[float, float, float]", tuple(r / ratio_sum for r in split_ratios)) - - allocated = [int(r * n_seq) for r in normalized_ratios] - - leftover = n_seq - sum(allocated) - if leftover > 0: - fractional_parts = [(r * n_seq) % 1 for r in normalized_ratios] - sorted_indices = sorted(range(3), key=lambda i: fractional_parts[i], reverse=True) - for i in range(leftover): - allocated[sorted_indices[i]] += 1 - - n_train = allocated[0] 
diff --git a/ml/utils/data_splits.py b/ml/utils/data_splits.py
deleted file mode 100644
index a9747f1..0000000
--- a/ml/utils/data_splits.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import random
-from pathlib import Path
-from typing import cast
-
-import numpy as np
-import torch
-
-
-def set_seed(seed: int) -> None:
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed_all(seed)
-
-
-def make_splits(
-    npz_dir: Path, split_ratios: tuple[float, float, float], seed: int
-) -> tuple[list[int], list[int], list[int]]:
-    seq_paths = sorted([p for p in npz_dir.iterdir() if p.name.startswith("seq_") and p.name.endswith(".npz")])
-    n_seq = len(seq_paths)
-
-    if n_seq < 1:
-        raise ValueError(f"Need at least 1 sequence, found {n_seq}")
-
-    indices = list(range(n_seq))
-    random.Random(seed).shuffle(indices)
-
-    ratio_sum = sum(split_ratios)
-    if ratio_sum == 0:
-        normalized_ratios = (1 / 3, 1 / 3, 1 / 3)
-    else:
-        normalized_ratios = cast("tuple[float, float, float]", tuple(r / ratio_sum for r in split_ratios))
-
-    allocated = [int(r * n_seq) for r in normalized_ratios]
-
-    leftover = n_seq - sum(allocated)
-    if leftover > 0:
-        fractional_parts = [(r * n_seq) % 1 for r in normalized_ratios]
-        sorted_indices = sorted(range(3), key=lambda i: fractional_parts[i], reverse=True)
-        for i in range(leftover):
-            allocated[sorted_indices[i]] += 1
-
-    n_train = allocated[0]
-    n_val = allocated[1]
-    n_test = allocated[2]
-
-    assert n_train + n_val + n_test == n_seq
-
-    train_idx = indices[:n_train]
-    val_idx = indices[n_train : n_train + n_val]
-    test_idx = indices[n_train + n_val :]
-
-    return train_idx, val_idx, test_idx
diff --git a/ml/utils/seed.py b/ml/utils/seed.py
new file mode 100644
index 0000000..9fd8ccf
--- /dev/null
+++ b/ml/utils/seed.py
@@ -0,0 +1,14 @@
+import random
+
+import numpy as np
+import torch
+
+
+def set_seed(seed: int) -> None:
+    """Set random seeds for reproducibility across Python, NumPy, and PyTorch."""
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
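Usage sketch for the new helper: re-seeding reproduces the same draws, which is what makes shuffling, augmentation flips, and weight init repeatable across runs (import path as used from `ml/`):

```python
import torch

from utils.seed import set_seed

set_seed(42)
a = torch.randn(3)
set_seed(42)
b = torch.randn(3)
assert torch.equal(a, b)  # identical draws after re-seeding
```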
"vdb_config", + "simulation_config", ] diff --git a/vdb-tools/config/simulation_config.yaml b/vdb-tools/config/simulation_config.yaml new file mode 100644 index 0000000..cb212c7 --- /dev/null +++ b/vdb-tools/config/simulation_config.yaml @@ -0,0 +1,52 @@ +simulation_generation: + splits: + enabled: true + ratios: [0.70, 0.15, 0.15] + names: ["train", "val", "test"] + seed: 42 + + distribution: + no_emitter_pct: 0.05 + no_collider_pct: 0.50 + collider_mode_simple_threshold: 0.20 + collider_mode_medium_threshold: 0.80 + + emitters: + count_range: [1, 2] + scale: + min: 0.1 + max: 0.3 + max_simple_mode: 0.2 + y_scale: 0.1 + position: + x_range: [-1.0, 1.0] + z_range: [-1.0, -0.2] + large_emitter: + threshold: 0.12 + x_range: [-0.6, 0.6] + z_position: -0.75 + + colliders: + simple_mode: + count_range: [1, 2] + scale: + min: 0.08 + max: 0.25 + y_scale: 0.1 + medium_mode: + count_range: [1, 2] + complex_mode: + count_range: [2, 3] + scale: + min: 0.3 + max: 0.8 + position: + z_range: [0.1, 1.0] + + domain: + y_scale: 0.05 + vorticity: 0.05 + beta: 0.0 + + animation: + max_displacement: 1.0e-5 diff --git a/vdb-tools/create_simulations.py b/vdb-tools/create_simulations.py index 983035f..e8b0ba9 100755 --- a/vdb-tools/create_simulations.py +++ b/vdb-tools/create_simulations.py @@ -6,48 +6,141 @@ import sys import time from concurrent.futures import ProcessPoolExecutor, as_completed +from dataclasses import dataclass +from math import ceil from pathlib import Path -from config import vdb_config +from config import PROJECT_ROOT_PATH, SimulationGenerationConfig, simulation_config, vdb_config -PROJECT_ROOT = Path(__file__).parent.parent BLENDER_SCRIPT = Path(__file__).parent / "blender_scripts/create_random_simulation.py" -# todo move all of that to yaml config file -# and add explanation for each parameters -NO_EMITTER_PCT = 0.05 -NO_COLLIDER_PCT = 0.50 -COLLIDER_MODE_SIMPLE_THRESHOLD = 0.20 -COLLIDER_MODE_MEDIUM_THRESHOLD = 0.80 +@dataclass +class SplitPlan: + split_name: str + total_count: int + no_emitter_count: int + no_collider_count: int + collider_simple_count: int + collider_medium_count: int + collider_complex_count: int + + +def compute_split_plan(split_count: int, split_name: str, gen_config: SimulationGenerationConfig) -> SplitPlan: + no_emitter_count = max(1, ceil(split_count * gen_config.distribution.no_emitter_pct)) + no_collider_count = max(1, ceil(no_emitter_count * gen_config.distribution.no_collider_pct)) + + sims_with_emitters = split_count - no_emitter_count + simple_thresh = gen_config.distribution.collider_mode_simple_threshold + medium_thresh = gen_config.distribution.collider_mode_medium_threshold + + collider_simple = max(1, ceil(sims_with_emitters * simple_thresh)) + collider_medium = max(1, ceil(sims_with_emitters * (medium_thresh - simple_thresh))) + collider_complex = max(0, sims_with_emitters - collider_simple - collider_medium) + + return SplitPlan( + split_name=split_name, + total_count=split_count, + no_emitter_count=no_emitter_count, + no_collider_count=no_collider_count, + collider_simple_count=collider_simple, + collider_medium_count=collider_medium, + collider_complex_count=collider_complex, + ) + + +def assign_simulations_to_splits(total_sims: int, gen_config: SimulationGenerationConfig) -> dict[str, SplitPlan]: + ratios = gen_config.splits.ratios + names = gen_config.splits.names + + ratio_sum = sum(ratios) + normalized_ratios = [r / ratio_sum for r in ratios] + allocated = [int(r * total_sims) for r in normalized_ratios] + + leftover = total_sims - 
diff --git a/vdb-tools/create_simulations.py b/vdb-tools/create_simulations.py
index 983035f..e8b0ba9 100755
--- a/vdb-tools/create_simulations.py
+++ b/vdb-tools/create_simulations.py
@@ -6,48 +6,141 @@
 import sys
 import time
 from concurrent.futures import ProcessPoolExecutor, as_completed
+from dataclasses import dataclass
+from math import ceil
 from pathlib import Path
 
-from config import vdb_config
+from config import PROJECT_ROOT_PATH, SimulationGenerationConfig, simulation_config, vdb_config
 
-PROJECT_ROOT = Path(__file__).parent.parent
 BLENDER_SCRIPT = Path(__file__).parent / "blender_scripts/create_random_simulation.py"
 
-# todo move all of that to yaml config file
-# and add explanation for each parameters
-NO_EMITTER_PCT = 0.05
-NO_COLLIDER_PCT = 0.50
-COLLIDER_MODE_SIMPLE_THRESHOLD = 0.20
-COLLIDER_MODE_MEDIUM_THRESHOLD = 0.80
 
+@dataclass
+class SplitPlan:
+    split_name: str
+    total_count: int
+    no_emitter_count: int
+    no_collider_count: int
+    collider_simple_count: int
+    collider_medium_count: int
+    collider_complex_count: int
+
+
+def compute_split_plan(split_count: int, split_name: str, gen_config: SimulationGenerationConfig) -> SplitPlan:
+    no_emitter_count = max(1, ceil(split_count * gen_config.distribution.no_emitter_pct))
+    no_collider_count = max(1, ceil(no_emitter_count * gen_config.distribution.no_collider_pct))
+
+    sims_with_emitters = split_count - no_emitter_count
+    simple_thresh = gen_config.distribution.collider_mode_simple_threshold
+    medium_thresh = gen_config.distribution.collider_mode_medium_threshold
+
+    collider_simple = max(1, ceil(sims_with_emitters * simple_thresh))
+    collider_medium = max(1, ceil(sims_with_emitters * (medium_thresh - simple_thresh)))
+    collider_complex = max(0, sims_with_emitters - collider_simple - collider_medium)
+
+    return SplitPlan(
+        split_name=split_name,
+        total_count=split_count,
+        no_emitter_count=no_emitter_count,
+        no_collider_count=no_collider_count,
+        collider_simple_count=collider_simple,
+        collider_medium_count=collider_medium,
+        collider_complex_count=collider_complex,
+    )
+
+
+def assign_simulations_to_splits(total_sims: int, gen_config: SimulationGenerationConfig) -> dict[str, SplitPlan]:
+    ratios = gen_config.splits.ratios
+    names = gen_config.splits.names
+
+    ratio_sum = sum(ratios)
+    normalized_ratios = [r / ratio_sum for r in ratios]
+    allocated = [int(r * total_sims) for r in normalized_ratios]
+
+    leftover = total_sims - sum(allocated)
+    if leftover > 0:
+        fractional = [(r * total_sims) % 1 for r in normalized_ratios]
+        sorted_idx = sorted(range(len(ratios)), key=lambda i: fractional[i], reverse=True)
+        for i in range(leftover):
+            allocated[sorted_idx[i]] += 1
+
+    return {name: compute_split_plan(count, name, gen_config) for name, count in zip(names, allocated, strict=True)}
+
+
+def pack_config(
+    gen_config: SimulationGenerationConfig, sim_index: int, base_seed: int, split_name: str, sim_type: dict
+) -> dict:
+    em = gen_config.emitters
+    col = gen_config.colliders
+
+    return {
+        "sim_index": sim_index,
+        "split_name": split_name,
+        "seed": base_seed + sim_index,
+        **sim_type,
+        "collider_mode": sim_type.get("collider_mode", "medium"),
+        # Emitter config
+        "emitter_count_range": list(em.count_range),
+        "emitter_scale_min": em.scale.min,
+        "emitter_scale_max_simple": em.scale.max_simple_mode,
+        "emitter_scale_max": em.scale.max,
+        "emitter_y_scale": em.scale.y_scale,
+        "emitter_x_range": list(em.position.x_range),
+        "emitter_z_range": list(em.position.z_range),
+        "large_emitter_threshold": em.large_emitter.threshold,
+        "large_emitter_x_range": list(em.large_emitter.x_range),
+        "large_emitter_z": em.large_emitter.z_position,
+        # Collider config
+        "collider_count_medium_range": list(col.medium_mode.count_range),
+        "collider_count_complex_range": list(col.complex_mode.count_range),
+        "collider_simple_scale_min": col.simple_mode.scale.min,
+        "collider_simple_scale_max": col.simple_mode.scale.max,
+        "collider_simple_y_scale": col.simple_mode.scale.y_scale,
+        "collider_complex_scale_min": col.complex_mode.scale.min,
+        "collider_complex_scale_max": col.complex_mode.scale.max,
+        "collider_z_range": list(col.position.z_range),
+        # Domain & animation
+        "domain_y_scale": gen_config.domain.y_scale,
+        "domain_vorticity": gen_config.domain.vorticity,
+        "domain_beta": gen_config.domain.beta,
+        "anim_max_displacement": gen_config.animation.max_displacement,
+    }
 
-EMITTER_COUNT_RANGE = (1, 2)
-COLLIDER_COUNT_MEDIUM_RANGE = (1, 2)
-COLLIDER_COUNT_COMPLEX_RANGE = (2, 3)
 
-EMITTER_SCALE_MIN = 0.1
-EMITTER_SCALE_MAX_SIMPLE = 0.2
-EMITTER_SCALE_MAX = 0.3
-EMITTER_Y_SCALE = 0.1
+def generate_simulation_configs(
+    split_plans: dict[str, SplitPlan], start_index: int, base_seed: int, gen_config: SimulationGenerationConfig
+) -> list[tuple[int, str, dict]]:
+    rng = random.Random(base_seed)
+    all_configs = []
+    sim_index = start_index
+
+    for split_name, plan in split_plans.items():
+        sim_types = []
+
+        for i in range(plan.no_emitter_count):
+            sim_types.append(
+                {
+                    "collider_mode": None,
+                    "no_emitters": True,
+                    "no_colliders": (i < plan.no_collider_count),
+                }
+            )
 
-COLLIDER_SIMPLE_SCALE_MIN = 0.08
-COLLIDER_SIMPLE_SCALE_MAX = 0.25
-COLLIDER_SIMPLE_Y_SCALE = 0.1
-COLLIDER_COMPLEX_SCALE_MIN = 0.3
-COLLIDER_COMPLEX_SCALE_MAX = 0.8
+        for _ in range(plan.collider_simple_count):
+            sim_types.append({"collider_mode": "simple", "no_emitters": False, "no_colliders": False})
+        for _ in range(plan.collider_medium_count):
+            sim_types.append({"collider_mode": "medium", "no_emitters": False, "no_colliders": False})
+        for _ in range(plan.collider_complex_count):
+            sim_types.append({"collider_mode": "complex", "no_emitters": False, "no_colliders": False})
 
-EMITTER_X_RANGE = (-1.0, 1.0)
-EMITTER_Z_RANGE = (-1.0, -0.2)
-COLLIDER_Z_RANGE = (0.1, 1.0)
-LARGE_EMITTER_THRESHOLD = 0.12
-LARGE_EMITTER_X_RANGE = (-0.6, 0.6)
-LARGE_EMITTER_Z = -0.75
+        rng.shuffle(sim_types)
 
-DOMAIN_Y_SCALE = 0.05
-DOMAIN_VORTICITY = 0.05
-DOMAIN_BETA = 0.0
+        for sim_type in sim_types:
+            config = pack_config(gen_config, sim_index, base_seed, split_name, sim_type)
+            all_configs.append((sim_index, split_name, config))
+            sim_index += 1
 
-ANIM_MAX_DISPLACEMENT = 1e-5
+    return all_configs
 
 
 def check_cache_exists(cache_dir: Path) -> bool:
@@ -64,17 +157,15 @@ def check_cache_exists(cache_dir: Path) -> bool:
 def generate_simulation(
    sim_index: int,
+    split_name: str,
    resolution: int,
    frames: int,
    output_base_dir: Path,
    blend_dir: Path,
-    seed: int,
-    collider_mode: str = "medium",
-    no_emitters: bool = False,
-    no_colliders: bool = False,
+    config_dict: dict,
 ) -> tuple[bool, str]:
    cache_name = f"cache_{sim_index:04d}"
-    resolution_dir = output_base_dir / str(resolution)
+    resolution_dir = output_base_dir / str(resolution) / split_name
    cache_dir = resolution_dir / cache_name
 
    blend_resolution_dir = blend_dir / str(resolution)
@@ -89,37 +180,12 @@ def generate_simulation(
    blend_resolution_dir.mkdir(parents=True, exist_ok=True)
 
    params = {
+        **config_dict,
        "resolution": resolution,
        "frames": frames,
        "cache_name": cache_name,
        "output_dir": str(cache_dir.absolute()),
        "blend_output_dir": str(blend_resolution_dir.absolute()),
-        "seed": seed,
-        "collider_mode": collider_mode,
-        "no_emitters": no_emitters,
-        "no_colliders": no_colliders,
-        "emitter_count_range": list(EMITTER_COUNT_RANGE),
-        "collider_count_medium_range": list(COLLIDER_COUNT_MEDIUM_RANGE),
-        "collider_count_complex_range": list(COLLIDER_COUNT_COMPLEX_RANGE),
-        "emitter_scale_min": EMITTER_SCALE_MIN,
-        "emitter_scale_max_simple": EMITTER_SCALE_MAX_SIMPLE,
-        "emitter_scale_max": EMITTER_SCALE_MAX,
-        "emitter_y_scale": EMITTER_Y_SCALE,
-        "collider_simple_scale_min": COLLIDER_SIMPLE_SCALE_MIN,
-        "collider_simple_scale_max": COLLIDER_SIMPLE_SCALE_MAX,
-        "collider_simple_y_scale": COLLIDER_SIMPLE_Y_SCALE,
-        "collider_complex_scale_min": COLLIDER_COMPLEX_SCALE_MIN,
-        "collider_complex_scale_max": COLLIDER_COMPLEX_SCALE_MAX,
-        "emitter_x_range": list(EMITTER_X_RANGE),
-        "emitter_z_range": list(EMITTER_Z_RANGE),
-        "collider_z_range": list(COLLIDER_Z_RANGE),
-        "large_emitter_threshold": LARGE_EMITTER_THRESHOLD,
-        "large_emitter_x_range": list(LARGE_EMITTER_X_RANGE),
-        "large_emitter_z": LARGE_EMITTER_Z,
-        "domain_y_scale": DOMAIN_Y_SCALE,
-        "domain_vorticity": DOMAIN_VORTICITY,
-        "domain_beta": DOMAIN_BETA,
-        "anim_max_displacement": ANIM_MAX_DISPLACEMENT,
    }
 
    blender_path = vdb_config.BLENDER_PATH
@@ -131,7 +197,9 @@ def generate_simulation(
 
    cmd = [str(blender_path), "--background", "--python", str(BLENDER_SCRIPT), "--", json.dumps(params)]
 
-    print(f"\nSimulation {sim_index}: {cache_name} (res={resolution}, frames={frames}, seed={seed})")
+    print(
+        f"\n[{split_name}] Simulation {sim_index}: {cache_name} (res={resolution}, frames={frames}, seed={config_dict['seed']})"
+    )
 
    try:
        result = subprocess.run(cmd, timeout=3600)
@@ -155,18 +223,16 @@ def generate_simulation(
 
 def worker_wrapper(task_args: tuple) -> tuple[int, bool, str]:
-    sim_index, resolution, frames, output_base_dir, blend_dir, seed, collider_mode, no_emitters, no_colliders = (
-        task_args
-    )
+    sim_index, split_name, resolution, frames, output_base_dir, blend_dir, config_dict = task_args
    success, status = generate_simulation(
-        sim_index, resolution, frames, output_base_dir, blend_dir, seed, collider_mode, no_emitters, no_colliders
+        sim_index, split_name, resolution, frames, output_base_dir, blend_dir, config_dict
    )
    return (sim_index, success, status)
 
 
 def main() -> None:
    parser = argparse.ArgumentParser(
simulations", + description="Generate batch of randomized Blender fluid simulations with split-based distribution", formatter_class=argparse.RawDescriptionHelpFormatter, ) @@ -197,64 +263,50 @@ def main() -> None: sys.exit(1) base_seed = args.seed if args.seed is not None else int(time.time()) - random.seed(base_seed) + gen_config = simulation_config + + output_base_dir = PROJECT_ROOT_PATH / "data" / "blender_caches" + blend_dir = PROJECT_ROOT_PATH / "data" / "simulations" - output_base_dir = PROJECT_ROOT / "data" / "blender_caches" - blend_dir = PROJECT_ROOT / "data" / "simulations" + split_plans = assign_simulations_to_splits(args.count, gen_config) print(f"\n{'=' * 70}") - print("Blender Simulation Batch Generator") + print("Split-Based Simulation Generation") print(f"{'=' * 70}") print("Configuration:") - print(f" Count: {args.count}") + print(f" Total count: {args.count}") print(f" Resolution: {args.resolution}") print(f" Start index: {args.start_index}") print(f" Frame range: [{args.min_frames}, {args.max_frames}]") print(f" Base seed: {base_seed}") print(f" Workers: {args.workers}") - print(f" Output: {output_base_dir / str(args.resolution)}") print(f" Blender: {vdb_config.BLENDER_PATH}") + print(f"\n{'=' * 70}") + print("Split Plan:") print(f"{'=' * 70}") + for split_name, plan in split_plans.items(): + print(f"\n{split_name.upper()} ({plan.total_count} sims):") + print(f" No emitters: {plan.no_emitter_count} (no colliders: {plan.no_collider_count})") + print(f" With emitters: {plan.total_count - plan.no_emitter_count}") + print( + f" Simple: {plan.collider_simple_count}, Medium: {plan.collider_medium_count}, Complex: {plan.collider_complex_count}" + ) + print(f"{'=' * 70}\n") - successful = 0 - failed = 0 - skipped = 0 - start_time = time.time() + sim_configs = generate_simulation_configs(split_plans, args.start_index, base_seed, gen_config) + + for split_name in split_plans.keys(): + (output_base_dir / str(args.resolution) / split_name).mkdir(parents=True, exist_ok=True) tasks = [] - for i in range(args.count): - sim_index = args.start_index + i + for sim_index, split_name, config_dict in sim_configs: frames = random.randint(args.min_frames, args.max_frames) - sim_seed = base_seed + sim_index + tasks.append((sim_index, split_name, args.resolution, frames, output_base_dir, blend_dir, config_dict)) - rand_val = random.random() - if rand_val < COLLIDER_MODE_SIMPLE_THRESHOLD: - collider_mode = "simple" - elif rand_val < COLLIDER_MODE_MEDIUM_THRESHOLD: - collider_mode = "medium" - else: - collider_mode = "complex" - - # Determine if simulation should have no emitters - no_emitters = random.random() < NO_EMITTER_PCT - no_colliders = False - if no_emitters: - # Of no-emitter sims, NO_COLLIDER_PCT have no colliders (empty) - no_colliders = random.random() < NO_COLLIDER_PCT - - tasks.append( - ( - sim_index, - args.resolution, - frames, - output_base_dir, - blend_dir, - sim_seed, - collider_mode, - no_emitters, - no_colliders, - ) - ) + successful = 0 + failed = 0 + skipped = 0 + start_time = time.time() with ProcessPoolExecutor(max_workers=args.workers) as executor: future_to_task = {executor.submit(worker_wrapper, task): task for task in tasks} diff --git a/vdb-tools/render_npz_field_to_png.py b/vdb-tools/render_npz_field_to_png.py index 540310c..2711466 100644 --- a/vdb-tools/render_npz_field_to_png.py +++ b/vdb-tools/render_npz_field_to_png.py @@ -4,11 +4,10 @@ import numpy as np from PIL import Image -from config import _PROJECT_ROOT, project_config +from config import _PROJECT_ROOT, 
diff --git a/vdb-tools/render_npz_field_to_png.py b/vdb-tools/render_npz_field_to_png.py
index 540310c..2711466 100644
--- a/vdb-tools/render_npz_field_to_png.py
+++ b/vdb-tools/render_npz_field_to_png.py
@@ -4,11 +4,10 @@
 import numpy as np
 from PIL import Image
 
-from config import _PROJECT_ROOT, project_config
+from config import _PROJECT_ROOT, project_config, simulation_config
 
 AVAILABLE_FIELDS = ["density", "velx", "velz", "vel_magnitude", "emitter", "collider"]
 
-INPUT_NPZ = _PROJECT_ROOT / project_config.vdb_tools.npz_output_directory / "128"
 OUTPUT_DIR = _PROJECT_ROOT / "data/npz_image_debug/"
 
@@ -174,43 +173,60 @@ def main() -> None:
        choices=AVAILABLE_FIELDS + ["all"],
        help=f"Field to render: {', '.join(AVAILABLE_FIELDS)}, or 'all' (default: density)",
    )
+    parser.add_argument(
+        "--split",
+        type=str,
+        default=None,
+        choices=simulation_config.splits.names,
+        help="Split to render ('train', 'val', or 'test'). If omitted, all splits are processed.",
+    )
+    parser.add_argument(
+        "--resolution",
+        type=str,
+        default="128",
+        help="Resolution directory to use (default: 128)",
+    )
    parser.add_argument("--scale", type=int, default=4, help="Upscale factor for output images (default: 4)")
 
    args = parser.parse_args()
 
-    input_path = Path(INPUT_NPZ)
-    output_path = Path(OUTPUT_DIR)
-    _ensure_dir(output_path)
-
-    processed = 0
+    splits = simulation_config.splits.names if args.split is None else [args.split]
    render_all = args.field == "all"
+    processed = 0
 
-    if input_path.is_dir():
-        files: list[Path] = sorted(
-            [f for f in input_path.iterdir() if f.name.startswith("seq_") and f.name.endswith(".npz")]
-        )
-        if not files:
-            raise FileNotFoundError(f"No seq_*.npz found in {args.input}")
-        for fp in files:
-            if render_all:
-                print(f"Rendering all fields from {fp}")
-                render_npz_all_fields(fp, output_path, scale=args.scale)
-            else:
-                n = render_npz_field(fp, output_path, field=args.field, scale=args.scale)
-                print(f"Rendered {n} {args.field} frames from {fp}")
-            processed += 1
-    else:
-        if not input_path.is_file():
-            raise FileNotFoundError(args.input)
-        if render_all:
-            print(f"Rendering all fields from {input_path}")
-            render_npz_all_fields(input_path, output_path, scale=args.scale)
-        else:
-            n = render_npz_field(input_path, output_path, field=args.field, scale=args.scale)
-            print(f"Rendered {n} {args.field} frames from {input_path}")
-        processed = 1
+    for split_name in splits:
+        input_path = _PROJECT_ROOT / project_config.vdb_tools.npz_output_directory / args.resolution / split_name
+        if not input_path.exists():
+            print(f"Warning: Split directory not found, skipping: {input_path}")
+            continue
+
+        output_path = Path(OUTPUT_DIR) / split_name
+        _ensure_dir(output_path)
+
+        print(f"\n{'=' * 70}")
+        print(f"Processing split: {split_name}")
+        print(f"{'=' * 70}")
+
+        if input_path.is_dir():
+            files: list[Path] = sorted(
+                [f for f in input_path.iterdir() if f.name.startswith("seq_") and f.name.endswith(".npz")]
+            )
+            if not files:
+                print(f"Warning: No seq_*.npz found in {input_path}, skipping")
+                continue
+            for fp in files:
+                if render_all:
+                    print(f"Rendering all fields from {fp.name}")
+                    render_npz_all_fields(fp, output_path, scale=args.scale)
+                else:
+                    n = render_npz_field(fp, output_path, field=args.field, scale=args.scale)
+                    print(f"Rendered {n} {args.field} frames from {fp.name}")
+                processed += 1
 
    field_desc = "all fields" if render_all else f"field: {args.field}"
-    print(f"Done. Processed {processed} sequence(s). {field_desc}. Output: {OUTPUT_DIR}")
+    splits_desc = "all splits" if args.split is None else f"split '{args.split}'"
+    print(f"\n{'=' * 70}")
Output: {OUTPUT_DIR}") + print(f"{'=' * 70}") if __name__ == "__main__": diff --git a/vdb-tools/vdb_core/batch_processing.py b/vdb-tools/vdb_core/batch_processing.py index 70f0cd5..9544424 100644 --- a/vdb-tools/vdb_core/batch_processing.py +++ b/vdb-tools/vdb_core/batch_processing.py @@ -429,14 +429,14 @@ def process_all_cache_sequences( save_frames: bool = False, percentiles: list[int] | None = None, normalization_percentile: int = 95, - stats_output_file: str = "data/_field_stats.yaml", + stats_output_file: str | None = "data/_field_stats.yaml", num_workers: int = 1, -) -> None: +) -> tuple[list, dict]: cache_data_dirs = discover_cache_sequences(blender_caches_root) if not cache_data_dirs: print(f"No cache sequences found in {blender_caches_root}") - return + return ([], {}) print(f"Found {len(cache_data_dirs)} cache sequences to process:") for cache_data_dir in cache_data_dirs: @@ -507,8 +507,8 @@ def process_all_cache_sequences( total_sequences += sequences_from_cache all_sequence_stats.extend(cache_stats) - # Compute and save global statistics - if all_sequence_stats: + # Compute and optionally save global statistics + if all_sequence_stats and stats_output_file is not None: try: global_stats = aggregate_global_stats(all_sequence_stats, percentiles=percentiles_to_use) normalization_scales = compute_normalization_scales(global_stats, normalization_percentile) @@ -537,9 +537,12 @@ def process_all_cache_sequences( except Exception as e: print(f"\nError: Failed to save statistics: {e}") print("This is non-critical; NPZ files were created successfully.") - else: + elif not all_sequence_stats: print("\nWarning: No sequences were created, no statistics computed.") print(f"\n{'=' * 60}") print(f"COMPLETE: Generated {total_sequences} total sequences in {output_dir}") print(f"{'=' * 60}") + + # Return stats and metadata for caller aggregation + return (all_sequence_stats, all_mesh_metadata) diff --git a/vdb-tools/vdb_to_numpy.py b/vdb-tools/vdb_to_numpy.py index 267e224..cbf78f1 100644 --- a/vdb-tools/vdb_to_numpy.py +++ b/vdb-tools/vdb_to_numpy.py @@ -1,9 +1,14 @@ import argparse import shutil -from vdb_core.batch_processing import process_all_cache_sequences +from vdb_core.batch_processing import ( + aggregate_global_stats, + compute_normalization_scales, + process_all_cache_sequences, + save_stats_to_yaml, +) -from config import PROJECT_ROOT_PATH, project_config +from config import PROJECT_ROOT_PATH, project_config, simulation_config def main() -> None: @@ -27,6 +32,12 @@ def main() -> None: default="all", help="Resolution to process (64, 128, 256, etc.) or 'all' to process all resolutions (default: all)", ) + parser.add_argument( + "--split", + type=str, + default=None, + help="Specific split to process ('train', 'val', or 'test'). 
If not specified, processes all splits.", + ) parser.add_argument( "--workers", type=int, @@ -52,36 +63,89 @@ def main() -> None: else: resolutions = [args.resolution] + if args.split is None: + splits = simulation_config.splits.names + else: + splits = [args.split] + for resolution_str in resolutions: resolution = int(resolution_str) - cache_dir = blender_caches_root / resolution_str - output_dir = npz_output_root / f"{resolution_str}" - - if not cache_dir.exists(): - print(f"Warning: Resolution directory not found, skipping: {cache_dir}") - continue - - print(f"\n{'=' * 70}") - print(f"Processing resolution: {resolution}x{resolution}") - print(f"{'=' * 70}") - print(f"Input directory: {cache_dir}") - print(f"Output directory: {output_dir}") - - if output_dir.exists(): - print(f"Clearing output directory: {output_dir}") - shutil.rmtree(output_dir) - - process_all_cache_sequences( - blender_caches_root=cache_dir, - output_dir=output_dir, - target_resolution=resolution, - max_frames=args.max_frames, - save_frames=args.save_frames, - percentiles=project_config.vdb_tools.stats_percentiles, - normalization_percentile=project_config.vdb_tools.normalization_percentile, - stats_output_file=project_config.vdb_tools.stats_output_file, - num_workers=args.workers, - ) + + # Collect stats from all splits for aggregation + # NOTE: Stats computed across all splits (not just train) to ensure proper value + # clamping to [0,1] or [-1,1]. Deviates from standard ML practice but required + # for correct normalization across full dataset. + # The goal is autoregressive rollout, not generalization on unseen data. + all_splits_stats: list = [] + all_splits_mesh_metadata: dict = {} + + for split_name in splits: + cache_dir = blender_caches_root / resolution_str / split_name + output_dir = npz_output_root / resolution_str / split_name + + if not cache_dir.exists(): + print(f"Warning: Split directory not found, skipping: {cache_dir}") + continue + + print(f"\n{'=' * 70}") + print(f"Processing: {resolution}x{resolution}, split: {split_name}") + print(f"{'=' * 70}") + print(f"Input: {cache_dir}") + print(f"Output: {output_dir}") + + if output_dir.exists(): + print(f"Clearing output directory: {output_dir}") + shutil.rmtree(output_dir) + + # Process split without saving stats (we'll aggregate all splits first) + split_stats, split_mesh_metadata = process_all_cache_sequences( + blender_caches_root=cache_dir, + output_dir=output_dir, + target_resolution=resolution, + max_frames=args.max_frames, + save_frames=args.save_frames, + percentiles=project_config.vdb_tools.stats_percentiles, + normalization_percentile=project_config.vdb_tools.normalization_percentile, + stats_output_file=None, # Don't save yet + num_workers=args.workers, + ) + + all_splits_stats.extend(split_stats) + all_splits_mesh_metadata.update(split_mesh_metadata) + + # Aggregate stats across all splits and save once + if all_splits_stats: + print(f"\n{'=' * 70}") + print(f"Aggregating statistics across all splits for resolution {resolution}...") + print(f"{'=' * 70}") + + global_stats = aggregate_global_stats( + all_splits_stats, percentiles=project_config.vdb_tools.stats_percentiles + ) + normalization_scales = compute_normalization_scales( + global_stats, project_config.vdb_tools.normalization_percentile + ) + stats_output_path = PROJECT_ROOT_PATH / project_config.vdb_tools.stats_output_file + + save_stats_to_yaml( + output_path=stats_output_path, + sequence_stats=all_splits_stats, + global_stats=global_stats, + 
+                normalization_scales=normalization_scales,
+                mesh_metadata=all_splits_mesh_metadata,
+            )
+
+            print("\nSTATISTICS SUMMARY (aggregated across all splits):")
+            print(f"  Total sequences analyzed: {global_stats.num_sequences}")
+            print(f"  Density range: [{global_stats.density.min:.6f}, {global_stats.density.max:.6f}]")
+            print(f"  Velx range: [{global_stats.velx.min:.6f}, {global_stats.velx.max:.6f}]")
+            print(f"  Velz range: [{global_stats.velz.min:.6f}, {global_stats.velz.max:.6f}]")
+            print(f"\n  Normalization scales (P{project_config.vdb_tools.normalization_percentile}):")
+            print(f"    S_density = {normalization_scales.S_density:.6f}")
+            print(f"    S_velx = {normalization_scales.S_velx:.6f}")
+            print(f"    S_velz = {normalization_scales.S_velz:.6f}")
+            print(f"\n  Saved aggregated statistics to: {stats_output_path}")
+            print(f"{'=' * 70}")
 
    print(f"\n{'=' * 70}")
    print("All resolutions processed!")
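Once generation and conversion have run, the resulting layout can be sanity-checked in a few lines (paths assumed from this patch; adjust the resolution as needed):

```python
# Count converted sequences per split: data/npz/<resolution>/<split>/seq_*.npz
from pathlib import Path

npz_root = Path("data/npz/128")
for split in ("train", "val", "test"):
    split_dir = npz_root / split
    n = len(list(split_dir.glob("seq_*.npz"))) if split_dir.exists() else 0
    print(f"{split}: {n} sequence(s)")
```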