From 0da59f433365effed28aac1d1d0e034c1db5ec19 Mon Sep 17 00:00:00 2001 From: ohmatheus Date: Thu, 19 Feb 2026 11:12:40 +0100 Subject: [PATCH 01/10] Added a vorticity range for data gen --- vdb-tools/config.py | 25 +++++++++++++++++++++---- vdb-tools/config/simulation_config.yaml | 4 +++- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/vdb-tools/config.py b/vdb-tools/config.py index 5159999..e18ece6 100644 --- a/vdb-tools/config.py +++ b/vdb-tools/config.py @@ -66,19 +66,27 @@ class EmitterConfig(BaseModel): class ColliderModeConfig(BaseModel): count_range: tuple[int, int] - scale: ScaleConfig | None = None + + +class ColliderScaledModeConfig(ColliderModeConfig): + scale: ScaleConfig class ColliderConfig(BaseModel): - simple_mode: ColliderModeConfig + simple_mode: ColliderScaledModeConfig medium_mode: ColliderModeConfig - complex_mode: ColliderModeConfig + complex_mode: ColliderScaledModeConfig position: ColliderPositionConfig +class VorticityConfig(BaseModel): + range: float | list[float] + step: float = 0.1 + + class DomainConfig(BaseModel): y_scale: float = 0.05 - vorticity: float = 0.05 + vorticity: VorticityConfig = VorticityConfig(range=0.05) beta: float = 0.0 @@ -95,6 +103,13 @@ class SimulationGenerationConfig(BaseModel): animation: AnimationConfig = AnimationConfig() +def get_vorticity_levels(v: VorticityConfig) -> list[float]: + if isinstance(v.range, (int, float)): + return [float(v.range)] + import numpy as np + return np.arange(v.range[0], v.range[1] + v.step / 2, v.step).tolist() + + def load_simulation_config() -> SimulationGenerationConfig: config_path = Path(__file__).parent / "config" / "simulation_config.yaml" if not config_path.exists(): @@ -120,4 +135,6 @@ class VDBSettings(BaseSettings): "project_config", "vdb_config", "simulation_config", + "VorticityConfig", + "get_vorticity_levels", ] diff --git a/vdb-tools/config/simulation_config.yaml b/vdb-tools/config/simulation_config.yaml index cb212c7..a94220d 100644 --- 
a/vdb-tools/config/simulation_config.yaml +++ b/vdb-tools/config/simulation_config.yaml @@ -45,7 +45,9 @@ simulation_generation: domain: y_scale: 0.05 - vorticity: 0.05 + vorticity: + range: [0.1, 0.4] + step: 0.1 beta: 0.0 animation: From 212ca8e2c2486b1e943b9bf1c91a3bd48c13f57b Mon Sep 17 00:00:00 2001 From: ohmatheus Date: Thu, 19 Feb 2026 11:25:48 +0100 Subject: [PATCH 02/10] Added vorticity range to generated simulations. Output simulation's vorticity in a meta file. Empty/non/empty sequences and splits should have same vorticity distribution. --- vdb-tools/config.py | 1 + vdb-tools/create_simulations.py | 87 +++++++++++++++++++++------------ 2 files changed, 56 insertions(+), 32 deletions(-) diff --git a/vdb-tools/config.py b/vdb-tools/config.py index e18ece6..30cc182 100644 --- a/vdb-tools/config.py +++ b/vdb-tools/config.py @@ -107,6 +107,7 @@ def get_vorticity_levels(v: VorticityConfig) -> list[float]: if isinstance(v.range, (int, float)): return [float(v.range)] import numpy as np + return np.arange(v.range[0], v.range[1] + v.step / 2, v.step).tolist() diff --git a/vdb-tools/create_simulations.py b/vdb-tools/create_simulations.py index e8b0ba9..02d36f3 100755 --- a/vdb-tools/create_simulations.py +++ b/vdb-tools/create_simulations.py @@ -10,7 +10,7 @@ from math import ceil from pathlib import Path -from config import PROJECT_ROOT_PATH, SimulationGenerationConfig, simulation_config, vdb_config +from config import PROJECT_ROOT_PATH, SimulationGenerationConfig, get_vorticity_levels, simulation_config, vdb_config BLENDER_SCRIPT = Path(__file__).parent / "blender_scripts/create_random_simulation.py" @@ -24,28 +24,36 @@ class SplitPlan: collider_simple_count: int collider_medium_count: int collider_complex_count: int + vorticity_levels: list[float] def compute_split_plan(split_count: int, split_name: str, gen_config: SimulationGenerationConfig) -> SplitPlan: - no_emitter_count = max(1, ceil(split_count * gen_config.distribution.no_emitter_pct)) - 
no_collider_count = max(1, ceil(no_emitter_count * gen_config.distribution.no_collider_pct)) + levels = get_vorticity_levels(gen_config.domain.vorticity) + n_levels = len(levels) - sims_with_emitters = split_count - no_emitter_count simple_thresh = gen_config.distribution.collider_mode_simple_threshold medium_thresh = gen_config.distribution.collider_mode_medium_threshold - collider_simple = max(1, ceil(sims_with_emitters * simple_thresh)) - collider_medium = max(1, ceil(sims_with_emitters * (medium_thresh - simple_thresh))) - collider_complex = max(0, sims_with_emitters - collider_simple - collider_medium) + raw_no_emitter = max(1, ceil(split_count * gen_config.distribution.no_emitter_pct)) + no_emitter_count = ceil(raw_no_emitter / n_levels) * n_levels + no_collider_count = max(1, ceil(raw_no_emitter * gen_config.distribution.no_collider_pct)) + + raw_fluid = split_count - raw_no_emitter + n_fluid_total = ceil(max(1, raw_fluid) / n_levels) * n_levels + + collider_simple = max(1, ceil(n_fluid_total * simple_thresh)) + collider_medium = max(1, ceil(n_fluid_total * (medium_thresh - simple_thresh))) + collider_complex = max(0, n_fluid_total - collider_simple - collider_medium) return SplitPlan( split_name=split_name, - total_count=split_count, + total_count=no_emitter_count + n_fluid_total, no_emitter_count=no_emitter_count, no_collider_count=no_collider_count, collider_simple_count=collider_simple, collider_medium_count=collider_medium, collider_complex_count=collider_complex, + vorticity_levels=levels, ) @@ -68,7 +76,12 @@ def assign_simulations_to_splits(total_sims: int, gen_config: SimulationGenerati def pack_config( - gen_config: SimulationGenerationConfig, sim_index: int, base_seed: int, split_name: str, sim_type: dict + gen_config: SimulationGenerationConfig, + sim_index: int, + base_seed: int, + split_name: str, + sim_type: dict, + domain_vorticity: float, ) -> dict: em = gen_config.emitters col = gen_config.colliders @@ -101,7 +114,7 @@ def pack_config( 
"collider_z_range": list(col.position.z_range), # Domain & animation "domain_y_scale": gen_config.domain.y_scale, - "domain_vorticity": gen_config.domain.vorticity, + "domain_vorticity": domain_vorticity, "domain_beta": gen_config.domain.beta, "anim_max_displacement": gen_config.animation.max_displacement, } @@ -115,28 +128,36 @@ def generate_simulation_configs( sim_index = start_index for split_name, plan in split_plans.items(): - sim_types = [] - - for i in range(plan.no_emitter_count): - sim_types.append( - { - "collider_mode": None, - "no_emitters": True, - "no_colliders": (i < plan.no_collider_count), - } - ) - - for _ in range(plan.collider_simple_count): - sim_types.append({"collider_mode": "simple", "no_emitters": False, "no_colliders": False}) - for _ in range(plan.collider_medium_count): - sim_types.append({"collider_mode": "medium", "no_emitters": False, "no_colliders": False}) - for _ in range(plan.collider_complex_count): - sim_types.append({"collider_mode": "complex", "no_emitters": False, "no_colliders": False}) - - rng.shuffle(sim_types) - - for sim_type in sim_types: - config = pack_config(gen_config, sim_index, base_seed, split_name, sim_type) + # Vorticity levels are pre-computed in SplitPlan so that group sizes are exact + # multiples of n_levels — each level appears exactly n_per_level times per group. + levels = plan.vorticity_levels + + # Empty-scene group: no fluid emitters, optionally no colliders either. + # First no_collider_count entries are fully empty (no emitters, no colliders). + empty_types = [ + {"collider_mode": None, "no_emitters": True, "no_colliders": (i < plan.no_collider_count)} + for i in range(plan.no_emitter_count) + ] + + # Fluid group: build the full list of collider-mode labels, then shuffle so + # collider complexity is randomised independently of vorticity level. 
+ fluid_types = ( + [{"collider_mode": "simple", "no_emitters": False, "no_colliders": False}] * plan.collider_simple_count + + [{"collider_mode": "medium", "no_emitters": False, "no_colliders": False}] * plan.collider_medium_count + + [{"collider_mode": "complex", "no_emitters": False, "no_colliders": False}] * plan.collider_complex_count + ) + rng.shuffle(fluid_types) + + # Assign vorticity by cycling through levels + for i, sim_type in enumerate(empty_types): + vorticity = levels[i % len(levels)] + config = pack_config(gen_config, sim_index, base_seed, split_name, sim_type, vorticity) + all_configs.append((sim_index, split_name, config)) + sim_index += 1 + + for j, sim_type in enumerate(fluid_types): + vorticity = levels[j % len(levels)] + config = pack_config(gen_config, sim_index, base_seed, split_name, sim_type, vorticity) all_configs.append((sim_index, split_name, config)) sim_index += 1 @@ -206,6 +227,8 @@ def generate_simulation( if result.returncode == 0: if check_cache_exists(cache_dir): + meta = {"vorticity": config_dict["domain_vorticity"], "split": split_name} + (cache_dir / "meta.json").write_text(json.dumps(meta)) return True, "success" else: print(" Error: Blender succeeded but no VDB files found") From b9a4be1e5a6d96bad2da4ef8fbbcc2b36dca930d Mon Sep 17 00:00:00 2001 From: ohmatheus Date: Thu, 19 Feb 2026 11:28:49 +0100 Subject: [PATCH 03/10] updated npz conversion with meta file 'passing' to npz folder --- vdb-tools/vdb_core/batch_processing.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/vdb-tools/vdb_core/batch_processing.py b/vdb-tools/vdb_core/batch_processing.py index 9544424..8e0de97 100644 --- a/vdb-tools/vdb_core/batch_processing.py +++ b/vdb-tools/vdb_core/batch_processing.py @@ -1,3 +1,4 @@ +import json import re from multiprocessing import Pool from pathlib import Path @@ -182,6 +183,15 @@ def process_single_cache_sequence( ) -> tuple[int, list[SequenceStats]]: cache_data_dir = Path(cache_data_dir) output_dir = 
Path(output_dir) + + meta_path = cache_data_dir.parent / "meta.json" + cache_vorticity: float | None = None + if meta_path.exists(): + with open(meta_path) as f: + cache_vorticity = json.load(f).get("vorticity") + else: + print(f"Warning: meta.json missing for {cache_data_dir.parent.name}, vorticity will be null") + vdb_files = sorted(cache_data_dir.glob("*.vdb")) if not vdb_files: @@ -360,6 +370,9 @@ def process_single_cache_sequence( collider=c_stack, ) print(f"Saved {out_path} shape: T={T}, H={H}, W={W} (density, velx, velz, emitter, collider)") + meta_out = output_dir / f"seq_{global_seq_num:04d}.meta.json" + with open(meta_out, "w") as f: + json.dump({"vorticity": cache_vorticity}, f) # Compute statistics for this sequence try: From ecbd7a7bf0a0b7903e7638d12369012bd929f66a Mon Sep 17 00:00:00 2001 From: ohmatheus Date: Thu, 19 Feb 2026 13:29:23 +0100 Subject: [PATCH 04/10] Updated dataset class with scalar --- ml/dataset/npz_sequence.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/ml/dataset/npz_sequence.py b/ml/dataset/npz_sequence.py index d8727c7..5a6607f 100644 --- a/ml/dataset/npz_sequence.py +++ b/ml/dataset/npz_sequence.py @@ -1,3 +1,4 @@ +import json import time from pathlib import Path @@ -253,7 +254,14 @@ def __init__( raise FileNotFoundError(f"No seq_*.npz files found in {npz_dir_path}") self.num_real_sequences = len(self.seq_paths) - # self.num_fake_sequences = _calculate_fake_count(self.num_real_sequences, fake_empty_pct) + + self._seq_scalars: list[float | None] = [] + for path in self.seq_paths: + meta_path = path.with_name(path.stem + ".meta.json") + if meta_path.exists(): + self._seq_scalars.append(json.loads(meta_path.read_text()).get("vorticity")) + else: + self._seq_scalars.append(None) # Load global normalization scales self._norm_scales: dict[str, float] | None = None @@ -437,14 +445,14 @@ def __len__(self) -> int: def __getitem__( self, idx: int - ) -> tuple[torch.Tensor, torch.Tensor, 
torch.Tensor] | tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: si, t = self._index[idx] - # Handle real sequences path = self.seq_paths[si] + cond_val = self._seq_scalars[si] + cond = torch.tensor(cond_val if cond_val is not None else 0.0, dtype=torch.float32) if self.rollout_steps > 1: - # Multi-step rollout mode if self.preload and self._preloaded_sequences is not None: x, y_seq, masks = self._load_rollout_sample_from_memory(si, t) else: @@ -455,9 +463,8 @@ def __getitem__( x, y_seq, masks = apply_rollout_augmentation(x, y_seq, masks, self.flip_probability) - return x, y_seq, masks + return x, y_seq, masks, cond else: - # Single-step mode (existing code) if self.preload and self._preloaded_sequences is not None: x, y = self._load_sample_from_memory(si, t) else: @@ -466,4 +473,4 @@ def __getitem__( if self.is_training and self.enable_augmentation: x, y = apply_augmentation(x, y, self.flip_probability) - return x, y + return x, y, cond From be0f6e609ff8f1f65fbc3507f8805396ee57b6fb Mon Sep 17 00:00:00 2001 From: ohmatheus Date: Thu, 19 Feb 2026 13:43:26 +0100 Subject: [PATCH 05/10] Updated Unet architecture with Film conditionning/embedding and FiLM layer --- .../model_architectures/UNet_medium.yaml | 4 +- ml/config/training_config.py | 2 + ml/models/unet.py | 154 ++++++++++++------ ml/scripts/train.py | 4 + 4 files changed, 111 insertions(+), 53 deletions(-) diff --git a/ml/config/model_architectures/UNet_medium.yaml b/ml/config/model_architectures/UNet_medium.yaml index b025289..b849a15 100644 --- a/ml/config/model_architectures/UNet_medium.yaml +++ b/ml/config/model_architectures/UNet_medium.yaml @@ -2,7 +2,7 @@ in_channels: 6 out_channels: 3 base_channels: 32 depth: 3 -norm: instance +norm: instance #group act: gelu group_norm_groups: 8 dropout: 0.1 @@ -12,3 +12,5 @@ padding_mode: replicate # "zeros", "reflect", "replicate", "circular" use_residual: true 
bottleneck_blocks: 1 output_activation: linear_clamp +use_film: true +film_cond_dim: 128 diff --git a/ml/config/training_config.py b/ml/config/training_config.py index a809c96..246338d 100644 --- a/ml/config/training_config.py +++ b/ml/config/training_config.py @@ -68,6 +68,8 @@ class TrainingConfig(BaseModel): use_residual: bool = True bottleneck_blocks: int = 1 output_activation: OutputActivationType = "linear_clamp" + use_film: bool = False + film_cond_dim: int = 128 # MLFlow settings mlflow_tracking_uri: str = "./mlruns" # mlflow server diff --git a/ml/models/unet.py b/ml/models/unet.py index e9a63bc..b6c70a7 100644 --- a/ml/models/unet.py +++ b/ml/models/unet.py @@ -14,18 +14,18 @@ OutputActivationType = Literal["sigmoid_tanh", "linear_clamp"] -def _norm(norm: NormType, ch: int, groups: int) -> nn.Module: +def _norm(norm: NormType, ch: int, groups: int, affine: bool = True) -> nn.Module: if norm == "none": return nn.Identity() if norm == "batch": return nn.BatchNorm2d(ch) if norm == "instance": - return nn.InstanceNorm2d(ch, affine=True) + return nn.InstanceNorm2d(ch, affine=affine) if norm == "group": g = min(groups, ch) while g > 1 and (ch % g) != 0: g -= 1 - return nn.GroupNorm(g, ch) + return nn.GroupNorm(g, ch, affine=affine) def _act(act: ActType) -> nn.Module: @@ -39,6 +39,29 @@ def _act(act: ActType) -> nn.Module: return nn.SiLU(inplace=True) +class ConditioningEncoder(nn.Module): + def __init__(self, cond_dim: int) -> None: + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(1, cond_dim), + nn.SiLU(), + nn.Linear(cond_dim, cond_dim), + ) + + def forward(self, c: torch.Tensor) -> torch.Tensor: + return self.mlp(c.unsqueeze(-1)) # (B,) → (B,1) → (B, cond_dim) + + +class FiLMLayer(nn.Module): + def __init__(self, ch: int, cond_dim: int) -> None: + super().__init__() + self.proj = nn.Linear(cond_dim, 2 * ch) + + def forward(self, x: torch.Tensor, cond_emb: torch.Tensor) -> torch.Tensor: + gamma, beta = self.proj(cond_emb).chunk(2, dim=-1) # each 
(B, C) + return gamma.unsqueeze(-1).unsqueeze(-1) * x + beta.unsqueeze(-1).unsqueeze(-1) + + class ConvBlock(nn.Module): def __init__( self, @@ -50,22 +73,33 @@ def __init__( groups: int, dropout: float, padding_mode: PaddingType = "zeros", + film_cond_dim: int = 0, ) -> None: super().__init__() + use_film = film_cond_dim > 0 self.c1 = nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode=padding_mode, bias=(norm == "none")) - self.n1 = _norm(norm, out_ch, groups) + self.n1 = _norm(norm, out_ch, groups, affine=not use_film) self.a1 = _act(act) self.c2 = nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode=padding_mode, bias=(norm == "none")) - self.n2 = _norm(norm, out_ch, groups) + self.n2 = _norm(norm, out_ch, groups, affine=not use_film) self.a2 = _act(act) self.drop = nn.Dropout2d(dropout) if dropout > 0.0 else nn.Identity() - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.a1(self.n1(self.c1(x))) + self.film1 = FiLMLayer(out_ch, film_cond_dim) if use_film else None + self.film2 = FiLMLayer(out_ch, film_cond_dim) if use_film else None + + def forward(self, x: torch.Tensor, cond_emb: torch.Tensor | None = None) -> torch.Tensor: + x = self.n1(self.c1(x)) + if self.film1 is not None and cond_emb is not None: + x = self.film1(x, cond_emb) + x = self.a1(x) x = self.drop(x) - x = self.a2(self.n2(self.c2(x))) + x = self.n2(self.c2(x)) + if self.film2 is not None and cond_emb is not None: + x = self.film2(x, cond_emb) + x = self.a2(x) return x @@ -79,12 +113,16 @@ def __init__( groups: int, dropout: float, padding_mode: PaddingType = "zeros", + film_cond_dim: int = 0, ) -> None: super().__init__() - self.b = ConvBlock(ch, ch, norm=norm, act=act, groups=groups, dropout=dropout, padding_mode=padding_mode) + self.b = ConvBlock( + ch, ch, norm=norm, act=act, groups=groups, dropout=dropout, + padding_mode=padding_mode, film_cond_dim=film_cond_dim, + ) - def forward(self, x: torch.Tensor) -> torch.Tensor: - return cast("torch.Tensor", x + self.b(x)) + def 
forward(self, x: torch.Tensor, cond_emb: torch.Tensor | None = None) -> torch.Tensor: + return cast("torch.Tensor", x + self.b(x, cond_emb)) def _downsample(mode: DownsampleType, ch: int, padding_mode: PaddingType) -> nn.Module: @@ -107,21 +145,27 @@ def __init__( use_residual: bool, downsample: DownsampleType = "stride", padding_mode: PaddingType = "zeros", + film_cond_dim: int = 0, ) -> None: super().__init__() self.block = ConvBlock( - in_ch, out_ch, norm=norm, act=act, groups=groups, dropout=dropout, padding_mode=padding_mode + in_ch, out_ch, norm=norm, act=act, groups=groups, dropout=dropout, + padding_mode=padding_mode, film_cond_dim=film_cond_dim, ) self.res = ( - ResBlock(out_ch, norm=norm, act=act, groups=groups, dropout=dropout, padding_mode=padding_mode) + ResBlock(out_ch, norm=norm, act=act, groups=groups, dropout=dropout, + padding_mode=padding_mode, film_cond_dim=film_cond_dim) if use_residual else nn.Identity() ) self.down = _downsample(downsample, out_ch, padding_mode) - def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - x = self.block(x) - x = self.res(x) + def forward(self, x: torch.Tensor, cond_emb: torch.Tensor | None = None) -> tuple[torch.Tensor, torch.Tensor]: + x = self.block(x, cond_emb) + if isinstance(self.res, ResBlock): + x = self.res(x, cond_emb) + else: + x = self.res(x) skip = x x = self.down(x) return x, skip @@ -141,10 +185,10 @@ def __init__( dropout: float, use_residual: bool, padding_mode: PaddingType = "zeros", + film_cond_dim: int = 0, ) -> None: super().__init__() - # Important for mypy: this attribute must accept both Upsample and ConvTranspose2d self.up: nn.Module if upsample == "transpose": self.up = nn.ConvTranspose2d(in_ch, in_ch, kernel_size=2, stride=2) @@ -156,18 +200,19 @@ def __init__( ) self.block = ConvBlock( - in_ch + skip_ch, out_ch, norm=norm, act=act, groups=groups, dropout=dropout, padding_mode=padding_mode + in_ch + skip_ch, out_ch, norm=norm, act=act, groups=groups, dropout=dropout, 
+ padding_mode=padding_mode, film_cond_dim=film_cond_dim, ) self.res = ( - ResBlock(out_ch, norm=norm, act=act, groups=groups, dropout=dropout, padding_mode=padding_mode) + ResBlock(out_ch, norm=norm, act=act, groups=groups, dropout=dropout, + padding_mode=padding_mode, film_cond_dim=film_cond_dim) if use_residual else nn.Identity() ) - def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, skip: torch.Tensor, cond_emb: torch.Tensor | None = None) -> torch.Tensor: x = self.up(x) - # If shapes don't match (odd sizes), crop skip to match x (center-ish). if x.shape[-2:] != skip.shape[-2:]: dh = skip.shape[-2] - x.shape[-2] dw = skip.shape[-1] - x.shape[-1] @@ -175,8 +220,11 @@ def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor: skip = skip[..., dh // 2 : dh // 2 + x.shape[-2], dw // 2 : dw // 2 + x.shape[-1]] x = torch.cat([x, skip], dim=1) - x = self.block(x) - x = self.res(x) + x = self.block(x, cond_emb) + if isinstance(self.res, ResBlock): + x = self.res(x, cond_emb) + else: + x = self.res(x) return x @@ -200,6 +248,9 @@ class UNetConfig: bottleneck_blocks: int = 1 output_activation: OutputActivationType = "linear_clamp" + use_film: bool = False + film_cond_dim: int = 128 + class UNet(nn.Module): """ @@ -215,6 +266,11 @@ def __init__(self, cfg: UNetConfig | None = None) -> None: if self.cfg.depth < 1: raise ValueError("depth must be >= 1") + film_cond_dim = self.cfg.film_cond_dim if self.cfg.use_film else 0 + + if self.cfg.use_film: + self.cond_encoder = ConditioningEncoder(self.cfg.film_cond_dim) + self.stem = nn.Conv2d( self.cfg.in_channels, self.cfg.base_channels, 3, padding=1, padding_mode=self.cfg.padding_mode ) @@ -226,15 +282,11 @@ def __init__(self, cfg: UNetConfig | None = None) -> None: out_ch = ch * 2 downs.append( Down( - ch, - out_ch, - norm=self.cfg.norm, - act=self.cfg.act, - groups=self.cfg.group_norm_groups, - dropout=self.cfg.dropout, - use_residual=self.cfg.use_residual, - 
downsample=self.cfg.downsample, - padding_mode=self.cfg.padding_mode, + ch, out_ch, + norm=self.cfg.norm, act=self.cfg.act, groups=self.cfg.group_norm_groups, + dropout=self.cfg.dropout, use_residual=self.cfg.use_residual, + downsample=self.cfg.downsample, padding_mode=self.cfg.padding_mode, + film_cond_dim=film_cond_dim, ) ) skip_chs.append(out_ch) @@ -246,14 +298,12 @@ def __init__(self, cfg: UNetConfig | None = None) -> None: mids.append( ResBlock( ch, - norm=self.cfg.norm, - act=self.cfg.act, - groups=self.cfg.group_norm_groups, - dropout=self.cfg.dropout, - padding_mode=self.cfg.padding_mode, + norm=self.cfg.norm, act=self.cfg.act, groups=self.cfg.group_norm_groups, + dropout=self.cfg.dropout, padding_mode=self.cfg.padding_mode, + film_cond_dim=film_cond_dim, ) ) - self.mid = nn.Sequential(*mids) if mids else nn.Identity() + self.mid = nn.ModuleList(mids) ups: list[nn.Module] = [] for i in reversed(range(self.cfg.depth)): @@ -261,16 +311,11 @@ def __init__(self, cfg: UNetConfig | None = None) -> None: out_ch = skip_ch // 2 ups.append( Up( - ch, - skip_ch, - out_ch, - upsample=self.cfg.upsample, - norm=self.cfg.norm, - act=self.cfg.act, - groups=self.cfg.group_norm_groups, - dropout=self.cfg.dropout, - use_residual=self.cfg.use_residual, - padding_mode=self.cfg.padding_mode, + ch, skip_ch, out_ch, + upsample=self.cfg.upsample, norm=self.cfg.norm, act=self.cfg.act, + groups=self.cfg.group_norm_groups, dropout=self.cfg.dropout, + use_residual=self.cfg.use_residual, padding_mode=self.cfg.padding_mode, + film_cond_dim=film_cond_dim, ) ) ch = out_ch @@ -278,18 +323,23 @@ def __init__(self, cfg: UNetConfig | None = None) -> None: self.head = nn.Conv2d(ch, self.cfg.out_channels, 1) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, cond_scalar: torch.Tensor | None = None) -> torch.Tensor: + cond_emb: torch.Tensor | None = None + if self.cfg.use_film and cond_scalar is not None: + cond_emb = self.cond_encoder(cond_scalar) + x = 
self.stem(x) skips: list[torch.Tensor] = [] for down in self.downs: - x, s = cast("Down", down)(x) + x, s = cast("Down", down)(x, cond_emb) skips.append(s) - x = self.mid(x) + for block in self.mid: + x = cast("ResBlock", block)(x, cond_emb) for up, skip in zip(self.ups, reversed(skips), strict=True): - x = cast("Up", up)(x, skip) + x = cast("Up", up)(x, skip, cond_emb) x = self.head(x) diff --git a/ml/scripts/train.py b/ml/scripts/train.py index fa60838..4560152 100644 --- a/ml/scripts/train.py +++ b/ml/scripts/train.py @@ -126,6 +126,8 @@ def train_single_variant( use_residual=config.use_residual, bottleneck_blocks=config.bottleneck_blocks, output_activation=config.output_activation, + use_film=config.use_film, + film_cond_dim=config.film_cond_dim, ) ).to(config.device) @@ -228,6 +230,8 @@ def train_single_variant( "use_residual": config.use_residual, "bottleneck_blocks": config.bottleneck_blocks, "output_activation": config.output_activation, + "use_film": config.use_film, + "film_cond_dim": config.film_cond_dim, # Checkpoint settings "save_every_n_epochs": config.save_every_n_epochs, "keep_last_n_checkpoints": config.keep_last_n_checkpoints, From 08971facaeef31954149f7ef02a0bd24e4c3b631 Mon Sep 17 00:00:00 2001 From: ohmatheus Date: Thu, 19 Feb 2026 13:51:42 +0100 Subject: [PATCH 06/10] updated trainer with FiLM conditioning --- ml/training/trainer.py | 47 ++++++++++++++++++++++-------------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/ml/training/trainer.py b/ml/training/trainer.py index 6b7aeb7..decc0e0 100644 --- a/ml/training/trainer.py +++ b/ml/training/trainer.py @@ -146,19 +146,20 @@ def train_epoch(self) -> dict[str, float]: pbar = tqdm(self.train_loader, desc="Training", leave=False) for batch_data in pbar: - # Detect rollout mode based on batch structure - if len(batch_data) == 3: - # Rollout mode: (x_0, y_seq, masks) - inputs, targets, masks = batch_data + if len(batch_data) == 4: + # Rollout mode: (x_0, y_seq, masks, cond) + 
inputs, targets, masks, cond = batch_data inputs = inputs.to(self.device, non_blocking=True) targets = targets.to(self.device, non_blocking=True) masks = masks.to(self.device, non_blocking=True) + cond = cond.to(self.device, non_blocking=True) rollout_mode = True else: - # Single-step mode: (x, y) - inputs, targets = batch_data + # Single-step mode: (x, y, cond) + inputs, targets, cond = batch_data inputs = inputs.to(self.device, non_blocking=True) targets = targets.to(self.device, non_blocking=True) + cond = cond.to(self.device, non_blocking=True) rollout_mode = False self.optimizer.zero_grad() @@ -167,9 +168,9 @@ def train_epoch(self) -> dict[str, float]: if self.scaler is not None: with autocast(device_type=self.device): if rollout_mode: - loss, loss_dict, _ = self._compute_rollout_loss(inputs, targets, masks) + loss, loss_dict, _ = self._compute_rollout_loss(inputs, targets, masks, cond) else: - outputs = self.model(inputs) + outputs = self.model(inputs, cond) loss, loss_dict = self.criterion(outputs, targets, inputs) self.scaler.scale(loss).backward() @@ -182,9 +183,9 @@ def train_epoch(self) -> dict[str, float]: self.scaler.update() else: if rollout_mode: - loss, loss_dict, _ = self._compute_rollout_loss(inputs, targets, masks) + loss, loss_dict, _ = self._compute_rollout_loss(inputs, targets, masks, cond) else: - outputs = self.model(inputs) + outputs = self.model(inputs, cond) loss, loss_dict = self.criterion(outputs, targets, inputs) loss.backward() @@ -216,34 +217,35 @@ def validate(self) -> dict[str, float]: with torch.no_grad(): pbar = tqdm(self.val_loader, desc="Validation", leave=False) for batch_data in pbar: - # Detect rollout mode based on batch structure - if len(batch_data) == 3: - # Rollout mode: (x_0, y_seq, masks) - inputs, targets, masks = batch_data + if len(batch_data) == 4: + # Rollout mode: (x_0, y_seq, masks, cond) + inputs, targets, masks, cond = batch_data inputs = inputs.to(self.device, non_blocking=True) targets = 
targets.to(self.device, non_blocking=True) masks = masks.to(self.device, non_blocking=True) + cond = cond.to(self.device, non_blocking=True) rollout_mode = True else: - # Single-step mode: (x, y) - inputs, targets = batch_data + # Single-step mode: (x, y, cond) + inputs, targets, cond = batch_data inputs = inputs.to(self.device, non_blocking=True) targets = targets.to(self.device, non_blocking=True) + cond = cond.to(self.device, non_blocking=True) rollout_mode = False # Compute loss (with or without AMP) if self.scaler is not None: with autocast(device_type=self.device): if rollout_mode: - loss, loss_dict, outputs = self._compute_rollout_loss(inputs, targets, masks) + loss, loss_dict, outputs = self._compute_rollout_loss(inputs, targets, masks, cond) else: - outputs = self.model(inputs) + outputs = self.model(inputs, cond) loss, loss_dict = self.criterion(outputs, targets, inputs) else: if rollout_mode: - loss, loss_dict, outputs = self._compute_rollout_loss(inputs, targets, masks) + loss, loss_dict, outputs = self._compute_rollout_loss(inputs, targets, masks, cond) else: - outputs = self.model(inputs) + outputs = self.model(inputs, cond) loss, loss_dict = self.criterion(outputs, targets, inputs) num_batches += 1 @@ -313,6 +315,7 @@ def _compute_rollout_loss( x_0: torch.Tensor, y_seq: torch.Tensor, masks: torch.Tensor, + cond: torch.Tensor, ) -> tuple[torch.Tensor, dict[str, float], torch.Tensor]: """ Compute loss for K-step autoregressive rollout. 
@@ -329,7 +332,7 @@ def _compute_rollout_loss( collider_k = masks[:, k, 1:2, :, :] model_input = torch.cat([state_current, state_prev, emitter_k, collider_k], dim=1) - pred = self.model(model_input) # (B, 3, H, W) + pred = self.model(model_input, cond) # (B, 3, H, W) if k == K - 1: target_final = y_seq[:, k, :, :, :] @@ -357,7 +360,7 @@ def _compute_rollout_loss( collider_k = masks[:, k, 1:2, :, :] model_input = torch.cat([state_current, state_prev, emitter_k, collider_k], dim=1) - pred = self.model(model_input) # (B, 3, H, W) + pred = self.model(model_input, cond) # (B, 3, H, W) target_k = y_seq[:, k, :, :, :] loss_k, loss_dict_k = self.criterion(pred, target_k, model_input) From 4ddd9f55409f19120cf106d609344ec34219fef8 Mon Sep 17 00:00:00 2001 From: ohmatheus Date: Thu, 19 Feb 2026 14:08:31 +0100 Subject: [PATCH 07/10] Updated training with post train test to check mse/vorticity level. + Moved all plot functions to a dedicated file to clean up trainer code. --- ml/dataset/npz_sequence.py | 4 +- ml/models/unet.py | 90 ++++++++++---- ml/scripts/train.py | 7 +- ml/training/test_evaluation.py | 109 +++++++++++++---- ml/training/trainer.py | 206 ++------------------------------- ml/training/training_plots.py | 165 ++++++++++++++++++++++++++ 6 files changed, 338 insertions(+), 243 deletions(-) create mode 100644 ml/training/training_plots.py diff --git a/ml/dataset/npz_sequence.py b/ml/dataset/npz_sequence.py index 5a6607f..b07cff5 100644 --- a/ml/dataset/npz_sequence.py +++ b/ml/dataset/npz_sequence.py @@ -445,7 +445,9 @@ def __len__(self) -> int: def __getitem__( self, idx: int - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> ( + tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor] + ): si, t = self._index[idx] path = self.seq_paths[si] diff --git a/ml/models/unet.py b/ml/models/unet.py index b6c70a7..0448f42 100644 --- 
a/ml/models/unet.py +++ b/ml/models/unet.py @@ -49,7 +49,7 @@ def __init__(self, cond_dim: int) -> None: ) def forward(self, c: torch.Tensor) -> torch.Tensor: - return self.mlp(c.unsqueeze(-1)) # (B,) → (B,1) → (B, cond_dim) + return cast("torch.Tensor", self.mlp(c.unsqueeze(-1))) # (B,) → (B,1) → (B, cond_dim) class FiLMLayer(nn.Module): @@ -58,7 +58,7 @@ def __init__(self, ch: int, cond_dim: int) -> None: self.proj = nn.Linear(cond_dim, 2 * ch) def forward(self, x: torch.Tensor, cond_emb: torch.Tensor) -> torch.Tensor: - gamma, beta = self.proj(cond_emb).chunk(2, dim=-1) # each (B, C) + gamma, beta = cast("torch.Tensor", self.proj(cond_emb)).chunk(2, dim=-1) # each (B, C) return gamma.unsqueeze(-1).unsqueeze(-1) * x + beta.unsqueeze(-1).unsqueeze(-1) @@ -117,8 +117,14 @@ def __init__( ) -> None: super().__init__() self.b = ConvBlock( - ch, ch, norm=norm, act=act, groups=groups, dropout=dropout, - padding_mode=padding_mode, film_cond_dim=film_cond_dim, + ch, + ch, + norm=norm, + act=act, + groups=groups, + dropout=dropout, + padding_mode=padding_mode, + film_cond_dim=film_cond_dim, ) def forward(self, x: torch.Tensor, cond_emb: torch.Tensor | None = None) -> torch.Tensor: @@ -149,12 +155,25 @@ def __init__( ) -> None: super().__init__() self.block = ConvBlock( - in_ch, out_ch, norm=norm, act=act, groups=groups, dropout=dropout, - padding_mode=padding_mode, film_cond_dim=film_cond_dim, + in_ch, + out_ch, + norm=norm, + act=act, + groups=groups, + dropout=dropout, + padding_mode=padding_mode, + film_cond_dim=film_cond_dim, ) self.res = ( - ResBlock(out_ch, norm=norm, act=act, groups=groups, dropout=dropout, - padding_mode=padding_mode, film_cond_dim=film_cond_dim) + ResBlock( + out_ch, + norm=norm, + act=act, + groups=groups, + dropout=dropout, + padding_mode=padding_mode, + film_cond_dim=film_cond_dim, + ) if use_residual else nn.Identity() ) @@ -200,12 +219,25 @@ def __init__( ) self.block = ConvBlock( - in_ch + skip_ch, out_ch, norm=norm, act=act, groups=groups, 
dropout=dropout, - padding_mode=padding_mode, film_cond_dim=film_cond_dim, + in_ch + skip_ch, + out_ch, + norm=norm, + act=act, + groups=groups, + dropout=dropout, + padding_mode=padding_mode, + film_cond_dim=film_cond_dim, ) self.res = ( - ResBlock(out_ch, norm=norm, act=act, groups=groups, dropout=dropout, - padding_mode=padding_mode, film_cond_dim=film_cond_dim) + ResBlock( + out_ch, + norm=norm, + act=act, + groups=groups, + dropout=dropout, + padding_mode=padding_mode, + film_cond_dim=film_cond_dim, + ) if use_residual else nn.Identity() ) @@ -282,10 +314,15 @@ def __init__(self, cfg: UNetConfig | None = None) -> None: out_ch = ch * 2 downs.append( Down( - ch, out_ch, - norm=self.cfg.norm, act=self.cfg.act, groups=self.cfg.group_norm_groups, - dropout=self.cfg.dropout, use_residual=self.cfg.use_residual, - downsample=self.cfg.downsample, padding_mode=self.cfg.padding_mode, + ch, + out_ch, + norm=self.cfg.norm, + act=self.cfg.act, + groups=self.cfg.group_norm_groups, + dropout=self.cfg.dropout, + use_residual=self.cfg.use_residual, + downsample=self.cfg.downsample, + padding_mode=self.cfg.padding_mode, film_cond_dim=film_cond_dim, ) ) @@ -298,8 +335,11 @@ def __init__(self, cfg: UNetConfig | None = None) -> None: mids.append( ResBlock( ch, - norm=self.cfg.norm, act=self.cfg.act, groups=self.cfg.group_norm_groups, - dropout=self.cfg.dropout, padding_mode=self.cfg.padding_mode, + norm=self.cfg.norm, + act=self.cfg.act, + groups=self.cfg.group_norm_groups, + dropout=self.cfg.dropout, + padding_mode=self.cfg.padding_mode, film_cond_dim=film_cond_dim, ) ) @@ -311,10 +351,16 @@ def __init__(self, cfg: UNetConfig | None = None) -> None: out_ch = skip_ch // 2 ups.append( Up( - ch, skip_ch, out_ch, - upsample=self.cfg.upsample, norm=self.cfg.norm, act=self.cfg.act, - groups=self.cfg.group_norm_groups, dropout=self.cfg.dropout, - use_residual=self.cfg.use_residual, padding_mode=self.cfg.padding_mode, + ch, + skip_ch, + out_ch, + upsample=self.cfg.upsample, + 
norm=self.cfg.norm, + act=self.cfg.act, + groups=self.cfg.group_norm_groups, + dropout=self.cfg.dropout, + use_residual=self.cfg.use_residual, + padding_mode=self.cfg.padding_mode, film_cond_dim=film_cond_dim, ) ) diff --git a/ml/scripts/train.py b/ml/scripts/train.py index 4560152..62062af 100644 --- a/ml/scripts/train.py +++ b/ml/scripts/train.py @@ -13,7 +13,7 @@ from dataset.npz_sequence import FluidNPZSequenceDataset from models.unet import UNet, UNetConfig from scripts.variant_manager import VariantManager -from training.test_evaluation import run_rollout_evaluation, run_test_evaluation +from training.test_evaluation import run_rollout_evaluation, run_test_evaluation, run_vorticity_mse_evaluation from training.trainer import Trainer from utils.seed import set_seed @@ -258,6 +258,11 @@ def train_single_variant( config=config, device=config.device, ) + run_vorticity_mse_evaluation( + model=model, + config=config, + device=config.device, + ) print(f"\nTraining complete: {config.variant.full_model_name}") print(f"Best model: {config.checkpoint_dir_variant / 'best_model.pth'}") diff --git a/ml/training/test_evaluation.py b/ml/training/test_evaluation.py index 5a3172b..90b944c 100644 --- a/ml/training/test_evaluation.py +++ b/ml/training/test_evaluation.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json import shutil import tempfile from pathlib import Path @@ -112,21 +113,20 @@ def run_test_evaluation( model.eval() with torch.no_grad(): pbar = tqdm(test_loader, desc="Test evaluation", leave=False) - for inputs, targets in pbar: + for inputs, targets, cond in pbar: inputs = inputs.to(device, non_blocking=True) targets = targets.to(device, non_blocking=True) + cond = cond.to(device, non_blocking=True) - # Model prediction if config.amp_enabled and device == "cuda": with autocast(device_type=device): - outputs = model(inputs) + outputs = model(inputs, cond) else: - outputs = model(inputs) + outputs = model(inputs, cond) model_metrics = 
_compute_batch_metrics(outputs, targets, inputs, config) model_tracker.update(model_metrics) - # Persistence baseline: predict current frame = next frame persistence_pred = inputs[:, 0:3, :, :] persistence_metrics = _compute_batch_metrics(persistence_pred, targets, inputs, config) persistence_tracker.update(persistence_metrics) @@ -134,7 +134,6 @@ def run_test_evaluation( model_avg = model_tracker.compute_averages() persistence_avg = persistence_tracker.compute_averages() - # Log to MLflow mlflow.log_metrics({f"test_{k}": v for k, v in model_avg.items()}) mlflow.log_metrics({f"test_persistence_{k}": v for k, v in persistence_avg.items()}) @@ -197,6 +196,7 @@ def _run_single_rollout( device: str, use_amp: bool, config: TrainingConfig, + cond_scalar: float | None = None, ) -> list[dict[str, float]]: d = data["density"] vx = data["velx"] @@ -209,7 +209,6 @@ def norm(arr: np.ndarray, key: str) -> np.ndarray: return arr / norm_scales[key] return arr - # Initial state d_t = norm(d[t_start], "S_density") d_tminus = norm(d[t_start - 1], "S_density") vx_t = norm(vx[t_start], "S_velx") @@ -217,6 +216,7 @@ def norm(arr: np.ndarray, key: str) -> np.ndarray: state_current = torch.tensor(np.stack([d_t, vx_t, vz_t], axis=0), dtype=torch.float32, device=device).unsqueeze(0) state_prev = torch.tensor(d_tminus[np.newaxis], dtype=torch.float32, device=device).unsqueeze(0) + cond_t = torch.tensor([cond_scalar or 0.0], dtype=torch.float32, device=device) step_metrics: list[dict[str, float]] = [] @@ -231,11 +231,10 @@ def norm(arr: np.ndarray, key: str) -> np.ndarray: if use_amp and device == "cuda": with autocast(device_type=device): - pred = model(model_input) + pred = model(model_input, cond_t) else: - pred = model(model_input) + pred = model(model_input, cond_t) - # GT for this step gt_d = norm(d[t_next], "S_density") gt_vx = norm(vx[t_next], "S_velx") gt_vz = norm(vz[t_next], "S_velz") @@ -285,7 +284,6 @@ def norm(arr: np.ndarray, key: str) -> np.ndarray: } ) - # Autoregressive 
update state_prev = state_current[:, 0:1, :, :] state_current = pred @@ -303,7 +301,6 @@ def _plot_rollout_degradation(avg_per_step: list[dict[str, float]]) -> Figure: fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(12, 5)) - # Left: MSE + SSIM ax_mse = ax_left ax_mse.plot(steps, mse_d, color="blue", linewidth=2, label="MSE Density") ax_mse.plot(steps, mse_vx, color="orange", linewidth=2, label="MSE Vel-X") @@ -322,10 +319,8 @@ def _plot_rollout_degradation(avg_per_step: list[dict[str, float]]) -> Figure: lines1, labels1 = ax_mse.get_legend_handles_labels() lines2, labels2 = ax_ssim.get_legend_handles_labels() ax_mse.legend(lines1 + lines2, labels1 + labels2, loc="best", fontsize=9) - ax_mse.set_title("MSE per Channel + SSIM") - # Right: Divergence + Gradient ax_div = ax_right ax_div.plot(steps, div_norm, color="red", linewidth=2, label="Divergence Norm") ax_div.set_xlabel("Autoregressive Step") @@ -341,7 +336,6 @@ def _plot_rollout_degradation(avg_per_step: list[dict[str, float]]) -> Figure: lines1, labels1 = ax_div.get_legend_handles_labels() lines2, labels2 = ax_grad.get_legend_handles_labels() ax_div.legend(lines1 + lines2, labels1 + labels2, loc="best", fontsize=9) - ax_div.set_title("Divergence Norm + Gradient L1") plt.suptitle("Rollout Degradation (30 steps)", fontsize=13, fontweight="bold") @@ -385,13 +379,17 @@ def run_rollout_evaluation( print(f"{'=' * 70}") print(f"Sequences: {len(test_paths)} | Steps: {ROLLOUT_STEPS} | Starting points: {len(STARTING_POINTS)}") - # Collect all rollout results: [step][rollout_idx] -> metrics all_rollouts: list[list[dict[str, float]]] = [[] for _ in range(ROLLOUT_STEPS)] total_rollouts = 0 model.eval() with torch.no_grad(): for seq_path in tqdm(test_paths, desc="Rollout evaluation", leave=False): + meta_path = seq_path.with_name(seq_path.stem + ".meta.json") + cond_scalar: float | None = None + if meta_path.exists(): + cond_scalar = json.loads(meta_path.read_text()).get("vorticity") + with np.load(seq_path) as npz: 
data = { "density": npz["density"].astype(np.float32), @@ -406,13 +404,12 @@ def run_rollout_evaluation( for t_start in starting_frames: step_metrics = _run_single_rollout( - model, data, t_start, norm_scales, device, config.amp_enabled, config + model, data, t_start, norm_scales, device, config.amp_enabled, config, cond_scalar ) for k, m in enumerate(step_metrics): all_rollouts[k].append(m) total_rollouts += 1 - # Average per step avg_per_step: list[dict[str, float]] = [] for k in range(ROLLOUT_STEPS): if not all_rollouts[k]: @@ -442,7 +439,6 @@ def run_rollout_evaluation( } avg_per_step.append(avg) - # Log scalars to MLflow mse_at_30 = avg_per_step[-1]["mse_density"] ssim_at_30 = avg_per_step[-1]["ssim_density"] div_at_30 = avg_per_step[-1]["divergence_norm"] @@ -452,11 +448,9 @@ def run_rollout_evaluation( mlflow.log_metric("test_rollout_divergence_norm_step30", div_at_30) mlflow.log_metric("test_rollout_gradient_l1_step30", grad_at_30) - # Plot and log fig = _plot_rollout_degradation(avg_per_step) log_artifact_flat(fig, "rollout_degradation.png", dpi=100) - # Console output print(f"\nTotal rollouts: {total_rollouts} ({len(test_paths)} sequences x {len(STARTING_POINTS)} starts)") print( f"\n{'Step':<8} {'MSE Density':>12} {'MSE Vel-X':>12} {'MSE Vel-Y':>12} " @@ -474,3 +468,76 @@ def run_rollout_evaluation( print("=" * 100) return {"avg_per_step": avg_per_step} + + +def run_vorticity_mse_evaluation( + model: nn.Module, + config: TrainingConfig, + device: str, +) -> None: + npz_dir = config.npz_dir / str(project_config.simulation.grid_resolution) + + test_ds = FluidNPZSequenceDataset( + npz_dir=npz_dir, + split="test", + normalize=config.normalize, + is_training=False, + preload=config.preload_dataset, + rollout_steps=1, + ) + + if not any(v is not None for v in test_ds._seq_scalars): + print("No vorticity scalars found in test set, skipping vorticity MSE evaluation.") + return + + test_loader = DataLoader(test_ds, batch_size=config.batch_size, shuffle=False, 
num_workers=config.num_workers) + + # {vorticity_level: [mse_density, mse_velx, mse_velz]} + grouped: dict[float, list[list[float]]] = {} + + model.eval() + with torch.no_grad(): + for inputs, targets, cond in tqdm(test_loader, desc="Vorticity MSE eval", leave=False): + inputs = inputs.to(device, non_blocking=True) + targets = targets.to(device, non_blocking=True) + cond = cond.to(device, non_blocking=True) + + if config.amp_enabled and device == "cuda": + with autocast(device_type=device): + outputs = model(inputs, cond) + else: + outputs = model(inputs, cond) + + for i in range(outputs.shape[0]): + level = round(cond[i].item(), 2) + mse_d = torch.mean((outputs[i, 0] - targets[i, 0]) ** 2).item() + mse_vx = torch.mean((outputs[i, 1] - targets[i, 1]) ** 2).item() + mse_vz = torch.mean((outputs[i, 2] - targets[i, 2]) ** 2).item() + grouped.setdefault(level, []).append([mse_d, mse_vx, mse_vz]) + + if not grouped: + return + + levels = sorted(grouped.keys()) + means = {lv: np.mean(grouped[lv], axis=0) for lv in levels} # (3,) per level + + fig, ax = plt.subplots(figsize=(max(6, len(levels) * 1.5), 5)) + x = np.arange(len(levels)) + width = 0.25 + colors = ["#2196F3", "#FF9800", "#4CAF50"] + labels = ["Density", "Vel-X", "Vel-Z"] + + for ci, (color, label) in enumerate(zip(colors, labels, strict=True)): + vals = [means[lv][ci] for lv in levels] + ax.bar(x + ci * width, vals, width, label=label, color=color, alpha=0.85) + + ax.set_xticks(x + width) + ax.set_xticklabels([str(lv) for lv in levels]) + ax.set_xlabel("Vorticity Level") + ax.set_ylabel("MSE") + ax.set_title("MSE by Vorticity Level (Test Set)") + ax.legend() + ax.grid(True, axis="y", alpha=0.3) + plt.tight_layout() + + log_artifact_flat(fig, "vorticity_mse_by_level.png", dpi=90) diff --git a/ml/training/trainer.py b/ml/training/trainer.py index decc0e0..ccbd4bf 100644 --- a/ml/training/trainer.py +++ b/ml/training/trainer.py @@ -2,11 +2,9 @@ from pathlib import Path from typing import Any, cast -import 
matplotlib.pyplot as plt import mlflow import torch import torch.nn as nn -from matplotlib.figure import Figure from torch.amp.autocast_mode import autocast from torch.amp.grad_scaler import GradScaler from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau, StepLR @@ -29,6 +27,7 @@ ) from training.physics_loss import PhysicsAwareLoss, StencilMode from training.test_evaluation import log_artifact_flat +from training.training_plots import plot_loss_components, plot_metrics_grid, plot_training_history class Trainer: @@ -414,197 +413,10 @@ def _create_dataloader_with_rollout_steps(self, K: int, is_training: bool) -> Da pin_memory=True if self.device == "cuda" else False, ) - def plot_training_history(self) -> Figure: - num_epochs = len(self.history["train_total"]) - - if num_epochs == 0: - raise ValueError("No training history to plot - training never started or was interrupted at epoch 0") - - epochs = list(range(1, num_epochs + 1)) - + def _best_epoch(self) -> int: if self.early_stopping is not None and self.early_stopping.best_epoch >= 0: - best_epoch = self.early_stopping.best_epoch + 1 - else: - min_val_idx = self.history["val_total"].index(min(self.history["val_total"])) - best_epoch = min_val_idx + 1 - - fig, ax_loss = plt.subplots(1, 1, figsize=(8, 5)) - - ax_loss.plot( - epochs, - self.history["train_total"], - color="#0066CC", - linewidth=2.0, - label="Train Total", - alpha=0.9, - ) - ax_loss.plot( - epochs, - self.history["val_total"], - color="#FF5500", - linewidth=2.0, - label="Val Total", - alpha=0.9, - ) - - ax_loss.axvline( - x=best_epoch, color="red", linestyle="--", linewidth=2.0, label=f"Best Epoch ({best_epoch})", alpha=0.7 - ) - - ax_loss.set_xlabel("Epoch", fontsize=12) - ax_loss.set_ylabel("Loss", fontsize=12, color="#333333") - ax_loss.set_title("Training and Validation Losses with Learning Rate", fontsize=14, fontweight="bold") - ax_loss.tick_params(axis="y", labelcolor="#333333") - ax_loss.grid(True, alpha=0.3) - - 
loss_range = max(self.history["train_total"]) / (min(self.history["train_total"]) + 1e-8) - if loss_range > 100: - ax_loss.set_yscale("log") - - ax_lr = ax_loss.twinx() - ax_lr.plot( - epochs, self.history["learning_rate"], color="#2ca02c", linewidth=2.0, label="Learning Rate", linestyle="--" - ) - ax_lr.set_ylabel("Learning Rate", fontsize=12, color="#2ca02c") - ax_lr.tick_params(axis="y", labelcolor="#2ca02c") - ax_lr.set_yscale("log") - - lines1, labels1 = ax_loss.get_legend_handles_labels() - lines2, labels2 = ax_lr.get_legend_handles_labels() - ax_loss.legend(lines1 + lines2, labels1 + labels2, loc="best", fontsize=10, framealpha=0.9) - - ax_loss.set_xticks(range(1, num_epochs + 1, max(1, num_epochs // 10))) - - plt.tight_layout() - - return fig - - def plot_loss_components(self) -> Figure: - num_epochs = len(self.history["train_total"]) - if num_epochs == 0: - raise ValueError("No training history to plot") - - epochs = list(range(1, num_epochs + 1)) - - if self.early_stopping is not None and self.early_stopping.best_epoch >= 0: - best_epoch = self.early_stopping.best_epoch + 1 - else: - min_val_idx = self.history["val_total"].index(min(self.history["val_total"])) - best_epoch = min_val_idx + 1 - - components = [] - if "train_mse" in self.history and len(self.history["train_mse"]) > 0: - components.append(("MSE", "train_mse", "val_mse")) - if "train_divergence" in self.history and len(self.history["train_divergence"]) > 0: - components.append(("Divergence", "train_divergence", "val_divergence")) - if "train_gradient" in self.history and len(self.history["train_gradient"]) > 0: - components.append(("Gradient", "train_gradient", "val_gradient")) - if "train_emitter" in self.history and len(self.history["train_emitter"]) > 0: - components.append(("Emitter", "train_emitter", "val_emitter")) - - n_components = len(components) - fig, axes = plt.subplots(2, 2, figsize=(10, 8)) - axes = axes.flatten() - - for idx, (name, train_key, val_key) in enumerate(components): 
- ax = axes[idx] - ax.plot(epochs, self.history[train_key], color="#0066CC", linewidth=2.0, label="Train", alpha=0.9) - ax.plot(epochs, self.history[val_key], color="#FF5500", linewidth=2.0, label="Val", alpha=0.9) - ax.axvline(x=best_epoch, color="red", linestyle="--", linewidth=1.5, alpha=0.7) - - ax.set_xlabel("Epoch", fontsize=10) - ax.set_ylabel(f"{name} Loss", fontsize=10) - ax.set_title(f"{name} Loss", fontsize=12, fontweight="bold") - ax.grid(True, alpha=0.3) - ax.legend(loc="best", fontsize=9) - - loss_range = max(self.history[train_key]) / (min(self.history[train_key]) + 1e-8) - if loss_range > 100: - ax.set_yscale("log") - - for idx in range(n_components, 4): - axes[idx].axis("off") - - plt.suptitle("Loss Components Over Epochs", fontsize=14, fontweight="bold") - plt.tight_layout() - - return fig - - def plot_metrics_grid(self) -> Figure: - """Create comprehensive metrics grid plot.""" - num_epochs = len(self.history["val_total"]) - if num_epochs == 0: - raise ValueError("No history to plot") - - epochs = list(range(1, num_epochs + 1)) - - fig, axes = plt.subplots(2, 4, figsize=(16, 8)) - axes = axes.flatten() - - axes[0].plot(epochs, self.history["val_mse_density"], label="Density", linewidth=2) - axes[0].plot(epochs, self.history["val_mse_velx"], label="Vel-X", linewidth=2) - axes[0].plot(epochs, self.history["val_mse_vely"], label="Vel-Y", linewidth=2) - axes[0].set_title("Per-Channel MSE") - axes[0].set_xlabel("Epoch") - axes[0].set_ylabel("MSE") - axes[0].legend() - axes[0].grid(True, alpha=0.3) - - axes[1].plot(epochs, self.history["val_divergence_norm"], color="red", linewidth=2) - axes[1].axhline(y=0.1, color="green", linestyle="--", label="Target < 0.1") - axes[1].set_title("Divergence Norm") - axes[1].set_xlabel("Epoch") - axes[1].set_ylabel("||∇·v||₂") - axes[1].legend() - axes[1].grid(True, alpha=0.3) - - axes[2].plot(epochs, self.history["val_kinetic_energy"], color="blue", linewidth=2) - axes[2].set_title("Kinetic Energy") - 
axes[2].set_xlabel("Epoch") - axes[2].set_ylabel("KE (trend monitor)") - axes[2].grid(True, alpha=0.3) - - axes[3].plot(epochs, self.history["val_collider_violation"], color="orange", linewidth=2) - axes[3].axhline(y=0.01, color="green", linestyle="--", label="Target < 0.01") - axes[3].set_title("Collider Violation") - axes[3].set_xlabel("Epoch") - axes[3].set_ylabel("Density in collider") - axes[3].legend() - axes[3].grid(True, alpha=0.3) - - axes[4].plot(epochs, self.history["val_emitter_accuracy"], color="purple", linewidth=2) - axes[4].axhline(y=0.1, color="green", linestyle="--", label="Target < 0.1") - axes[4].set_title("Emitter Density Accuracy") - axes[4].set_xlabel("Epoch") - axes[4].set_ylabel("Injection error") - axes[4].legend() - axes[4].grid(True, alpha=0.3) - - axes[5].plot(epochs, self.history["val_ssim_density"], color="teal", linewidth=2) - axes[5].axhline(y=0.9, color="green", linestyle="--", label="Target > 0.9") - axes[5].set_title("SSIM (Density)") - axes[5].set_xlabel("Epoch") - axes[5].set_ylabel("SSIM") - axes[5].legend() - axes[5].grid(True, alpha=0.3) - - axes[6].plot(epochs, self.history["val_gradient_l1"], color="brown", linewidth=2) - axes[6].set_title("Gradient L1 (Edge Sharpness)") - axes[6].set_xlabel("Epoch") - axes[6].set_ylabel("L1 gradient error") - axes[6].grid(True, alpha=0.3) - - axes[7].axis("off") - - # Check if log scale needed for MSE - mse_range = max(self.history["val_mse_density"]) / (min(self.history["val_mse_density"]) + 1e-8) - if mse_range > 100: - axes[0].set_yscale("log") - - plt.suptitle("Validation Metrics", fontsize=16, fontweight="bold") - plt.tight_layout() - - return fig + return self.early_stopping.best_epoch + 1 + return self.history["val_total"].index(min(self.history["val_total"])) + 1 def train(self) -> None: print(f"Starting training for {self.config.epochs} epochs...") @@ -694,23 +506,21 @@ def train(self) -> None: else: self.save_checkpoint(self.config.epochs - 1, final=True) + best_epoch = 
self._best_epoch() try: - fig = self.plot_training_history() - log_artifact_flat(fig, "training_loss_and_lr.png", dpi=144) + log_artifact_flat(plot_training_history(self.history, best_epoch), "training_loss_and_lr.png", dpi=144) print("Training history plot saved to MLflow artifacts") except Exception as e: print(f"Warning: Failed to generate/save training history plot - {e}") try: - fig_metrics = self.plot_metrics_grid() - log_artifact_flat(fig_metrics, "validation_metrics_grid.png", dpi=90) + log_artifact_flat(plot_metrics_grid(self.history), "validation_metrics_grid.png", dpi=90) print("Metrics grid plot saved to MLflow artifacts") except Exception as e: print(f"Warning: Failed to generate/save metrics grid plot - {e}") try: - fig_loss_components = self.plot_loss_components() - log_artifact_flat(fig_loss_components, "loss_components.png", dpi=90) + log_artifact_flat(plot_loss_components(self.history, best_epoch), "loss_components.png", dpi=90) print("Loss components plot saved to MLflow artifacts") except Exception as e: print(f"Warning: Failed to generate/save loss components plot - {e}") diff --git a/ml/training/training_plots.py b/ml/training/training_plots.py new file mode 100644 index 0000000..bc88aae --- /dev/null +++ b/ml/training/training_plots.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import matplotlib.pyplot as plt + +if TYPE_CHECKING: + from matplotlib.figure import Figure + + +def plot_training_history(history: dict[str, list[float]], best_epoch: int) -> Figure: + num_epochs = len(history["train_total"]) + if num_epochs == 0: + raise ValueError("No training history to plot - training never started or was interrupted at epoch 0") + + epochs = list(range(1, num_epochs + 1)) + + fig, ax_loss = plt.subplots(1, 1, figsize=(8, 5)) + + ax_loss.plot(epochs, history["train_total"], color="#0066CC", linewidth=2.0, label="Train Total", alpha=0.9) + ax_loss.plot(epochs, history["val_total"], color="#FF5500", 
linewidth=2.0, label="Val Total", alpha=0.9) + ax_loss.axvline( + x=best_epoch, color="red", linestyle="--", linewidth=2.0, label=f"Best Epoch ({best_epoch})", alpha=0.7 + ) + + ax_loss.set_xlabel("Epoch", fontsize=12) + ax_loss.set_ylabel("Loss", fontsize=12, color="#333333") + ax_loss.set_title("Training and Validation Losses with Learning Rate", fontsize=14, fontweight="bold") + ax_loss.tick_params(axis="y", labelcolor="#333333") + ax_loss.grid(True, alpha=0.3) + + loss_range = max(history["train_total"]) / (min(history["train_total"]) + 1e-8) + if loss_range > 100: + ax_loss.set_yscale("log") + + ax_lr = ax_loss.twinx() + ax_lr.plot(epochs, history["learning_rate"], color="#2ca02c", linewidth=2.0, label="Learning Rate", linestyle="--") + ax_lr.set_ylabel("Learning Rate", fontsize=12, color="#2ca02c") + ax_lr.tick_params(axis="y", labelcolor="#2ca02c") + ax_lr.set_yscale("log") + + lines1, labels1 = ax_loss.get_legend_handles_labels() + lines2, labels2 = ax_lr.get_legend_handles_labels() + ax_loss.legend(lines1 + lines2, labels1 + labels2, loc="best", fontsize=10, framealpha=0.9) + ax_loss.set_xticks(range(1, num_epochs + 1, max(1, num_epochs // 10))) + + plt.tight_layout() + return fig + + +def plot_loss_components(history: dict[str, list[float]], best_epoch: int) -> Figure: + num_epochs = len(history["train_total"]) + if num_epochs == 0: + raise ValueError("No training history to plot") + + epochs = list(range(1, num_epochs + 1)) + + components = [] + if history.get("train_mse"): + components.append(("MSE", "train_mse", "val_mse")) + if history.get("train_divergence"): + components.append(("Divergence", "train_divergence", "val_divergence")) + if history.get("train_gradient"): + components.append(("Gradient", "train_gradient", "val_gradient")) + if history.get("train_emitter"): + components.append(("Emitter", "train_emitter", "val_emitter")) + + n_components = len(components) + fig, axes = plt.subplots(2, 2, figsize=(10, 8)) + axes = axes.flatten() + + for 
idx, (name, train_key, val_key) in enumerate(components): + ax = axes[idx] + ax.plot(epochs, history[train_key], color="#0066CC", linewidth=2.0, label="Train", alpha=0.9) + ax.plot(epochs, history[val_key], color="#FF5500", linewidth=2.0, label="Val", alpha=0.9) + ax.axvline(x=best_epoch, color="red", linestyle="--", linewidth=1.5, alpha=0.7) + ax.set_xlabel("Epoch", fontsize=10) + ax.set_ylabel(f"{name} Loss", fontsize=10) + ax.set_title(f"{name} Loss", fontsize=12, fontweight="bold") + ax.grid(True, alpha=0.3) + ax.legend(loc="best", fontsize=9) + loss_range = max(history[train_key]) / (min(history[train_key]) + 1e-8) + if loss_range > 100: + ax.set_yscale("log") + + for idx in range(n_components, 4): + axes[idx].axis("off") + + plt.suptitle("Loss Components Over Epochs", fontsize=14, fontweight="bold") + plt.tight_layout() + return fig + + +def plot_metrics_grid(history: dict[str, list[float]]) -> Figure: + num_epochs = len(history["val_total"]) + if num_epochs == 0: + raise ValueError("No history to plot") + + epochs = list(range(1, num_epochs + 1)) + + fig, axes = plt.subplots(2, 4, figsize=(16, 8)) + axes = axes.flatten() + + axes[0].plot(epochs, history["val_mse_density"], label="Density", linewidth=2) + axes[0].plot(epochs, history["val_mse_velx"], label="Vel-X", linewidth=2) + axes[0].plot(epochs, history["val_mse_vely"], label="Vel-Y", linewidth=2) + axes[0].set_title("Per-Channel MSE") + axes[0].set_xlabel("Epoch") + axes[0].set_ylabel("MSE") + axes[0].legend() + axes[0].grid(True, alpha=0.3) + + axes[1].plot(epochs, history["val_divergence_norm"], color="red", linewidth=2) + axes[1].axhline(y=0.1, color="green", linestyle="--", label="Target < 0.1") + axes[1].set_title("Divergence Norm") + axes[1].set_xlabel("Epoch") + axes[1].set_ylabel("||∇·v||₂") + axes[1].legend() + axes[1].grid(True, alpha=0.3) + + axes[2].plot(epochs, history["val_kinetic_energy"], color="blue", linewidth=2) + axes[2].set_title("Kinetic Energy") + axes[2].set_xlabel("Epoch") + 
axes[2].set_ylabel("KE (trend monitor)") + axes[2].grid(True, alpha=0.3) + + axes[3].plot(epochs, history["val_collider_violation"], color="orange", linewidth=2) + axes[3].axhline(y=0.01, color="green", linestyle="--", label="Target < 0.01") + axes[3].set_title("Collider Violation") + axes[3].set_xlabel("Epoch") + axes[3].set_ylabel("Density in collider") + axes[3].legend() + axes[3].grid(True, alpha=0.3) + + axes[4].plot(epochs, history["val_emitter_accuracy"], color="purple", linewidth=2) + axes[4].axhline(y=0.1, color="green", linestyle="--", label="Target < 0.1") + axes[4].set_title("Emitter Density Accuracy") + axes[4].set_xlabel("Epoch") + axes[4].set_ylabel("Injection error") + axes[4].legend() + axes[4].grid(True, alpha=0.3) + + axes[5].plot(epochs, history["val_ssim_density"], color="teal", linewidth=2) + axes[5].axhline(y=0.9, color="green", linestyle="--", label="Target > 0.9") + axes[5].set_title("SSIM (Density)") + axes[5].set_xlabel("Epoch") + axes[5].set_ylabel("SSIM") + axes[5].legend() + axes[5].grid(True, alpha=0.3) + + axes[6].plot(epochs, history["val_gradient_l1"], color="brown", linewidth=2) + axes[6].set_title("Gradient L1 (Edge Sharpness)") + axes[6].set_xlabel("Epoch") + axes[6].set_ylabel("L1 gradient error") + axes[6].grid(True, alpha=0.3) + + axes[7].axis("off") + + mse_range = max(history["val_mse_density"]) / (min(history["val_mse_density"]) + 1e-8) + if mse_range > 100: + axes[0].set_yscale("log") + + plt.suptitle("Validation Metrics", fontsize=16, fontweight="bold") + plt.tight_layout() + return fig From b55a91417d6bf8f12d9c1241b5b52f7b2de50330 Mon Sep 17 00:00:00 2001 From: ohmatheus Date: Thu, 19 Feb 2026 14:11:46 +0100 Subject: [PATCH 08/10] Updated onnx export with the new FilM input --- ml/scripts/export_to_onnx.py | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/ml/scripts/export_to_onnx.py b/ml/scripts/export_to_onnx.py index 93d284a..e5ef608 100644 --- 
a/ml/scripts/export_to_onnx.py +++ b/ml/scripts/export_to_onnx.py @@ -114,6 +114,8 @@ def load_model_from_checkpoint(checkpoint_path: Path, device: str) -> UNet: use_residual = config.get("use_residual", True) bottleneck_blocks = config.get("bottleneck_blocks", 1) output_activation = config.get("output_activation", "linear_clamp") + use_film = config.get("use_film", False) + film_cond_dim = config.get("film_cond_dim", 128) model = UNet( cfg=UNetConfig( @@ -131,6 +133,8 @@ def load_model_from_checkpoint(checkpoint_path: Path, device: str) -> UNet: use_residual=use_residual, bottleneck_blocks=bottleneck_blocks, output_activation=output_activation, + use_film=use_film, + film_cond_dim=film_cond_dim, ) ) @@ -152,21 +156,37 @@ def export_model_to_onnx( output_path: Path, input_shape: tuple[int, int, int, int], device: str, + use_film: bool = False, ) -> None: output_path.parent.mkdir(parents=True, exist_ok=True) dummy_input = torch.randn(*input_shape, device=device) + inputs: tuple[torch.Tensor, ...] 
+ if use_film: + dummy_cond = torch.zeros(input_shape[0], device=device) # (B,) + inputs = (dummy_input, dummy_cond) + input_names = ["input", "cond"] + dynamic_axes: dict[str, dict[int, str]] | None = { + "input": {0: "batch"}, + "cond": {0: "batch"}, + "output": {0: "batch"}, + } + else: + inputs = (dummy_input,) + input_names = ["input"] + dynamic_axes = None + torch.onnx.export( model, - (dummy_input,), + inputs, str(output_path), export_params=True, opset_version=ONNX_OPSET_VERSION, do_constant_folding=True, - input_names=["input"], + input_names=input_names, output_names=["output"], - dynamic_axes=None, # Fixed input size for now + dynamic_axes=dynamic_axes, ) file_size_mb = output_path.stat().st_size / (1024 * 1024) @@ -254,10 +274,12 @@ def export_checkpoint( variants_info = {} try: + model = load_model_from_checkpoint(checkpoint_path, device) + use_film = model.cfg.use_film + if not onnx_path.exists(): print(" Exporting FP32...") - model = load_model_from_checkpoint(checkpoint_path, device) - export_model_to_onnx(model, onnx_path, input_shape, device) + export_model_to_onnx(model, onnx_path, input_shape, device, use_film=use_film) else: print(" FP32 exists, skipping") From 3ba520a3fd07bfc760db440e351223cf2cdc3563 Mon Sep 17 00:00:00 2001 From: ohmatheus Date: Thu, 19 Feb 2026 14:22:34 +0100 Subject: [PATCH 09/10] Engine reads new onnx and changes vorticity with a slider in RT --- config.yaml | 4 ++++ engine/include/Config.hpp | 20 +++++++++++++++++ engine/include/FluidScene.hpp | 5 +++++ engine/include/Simulation.hpp | 5 +++++ engine/src/Config.cpp | 16 ++++++++++++++ engine/src/FluidScene.cpp | 14 ++++++++++++ engine/src/Simulation.cpp | 41 ++++++++++++++++++++++++++++++----- 7 files changed, 100 insertions(+), 5 deletions(-) diff --git a/config.yaml b/config.yaml index 5236bde..4a6a030 100644 --- a/config.yaml +++ b/config.yaml @@ -14,6 +14,10 @@ simulation: fps: 24.0 grid_resolution: 128 input_channels: 6 #[density, velx, vely, density-1, emitter_mask,
collider_mask] + vorticity_default: 0.2 + vorticity_min: 0.1 + vorticity_max: 0.4 + vorticity_step: 0.1 engine: window_width: 1250 diff --git a/engine/include/Config.hpp b/engine/include/Config.hpp index 891e4c5..93fc6ab 100644 --- a/engine/include/Config.hpp +++ b/engine/include/Config.hpp @@ -20,6 +20,10 @@ struct SimulationConfig float fps{30.0f}; int gridResolution{128}; int inputChannels{0}; + float vorticityDefault{0.2f}; + float vorticityMin{0.1f}; + float vorticityMax{0.4f}; + float vorticityStep{0.1f}; }; struct ModelsConfig @@ -65,6 +69,22 @@ class Config { return m_simulationConfig.inputChannels; } + float getVorticityDefault() const + { + return m_simulationConfig.vorticityDefault; + } + float getVorticityMin() const + { + return m_simulationConfig.vorticityMin; + } + float getVorticityMax() const + { + return m_simulationConfig.vorticityMax; + } + float getVorticityStep() const + { + return m_simulationConfig.vorticityStep; + } const std::filesystem::path& getModelsFolder() const { diff --git a/engine/include/FluidScene.hpp b/engine/include/FluidScene.hpp index dd69114..a652ea8 100644 --- a/engine/include/FluidScene.hpp +++ b/engine/include/FluidScene.hpp @@ -72,6 +72,11 @@ class FluidScene final : public Scene float m_velocityStrength{0.5f}; float m_velocityDecayPercent{5.0f}; + float m_vorticity{0.2f}; + float m_vorticityMin{0.1f}; + float m_vorticityMax{0.4f}; + float m_vorticityStep{0.1f}; + int m_prevMouseGridX{-1}; int m_prevMouseGridY{-1}; bool m_mousePressed{false}; diff --git a/engine/include/Simulation.hpp b/engine/include/Simulation.hpp index 705e4d4..b75c500 100644 --- a/engine/include/Simulation.hpp +++ b/engine/include/Simulation.hpp @@ -28,6 +28,9 @@ class Simulation return !m_useGpu; } + void setVorticity(float v); + float getVorticity() const; + const SimulationBuffer* getLatestState() const; float getAvgComputeTimeMs() const; @@ -58,6 +61,8 @@ class Simulation const std::atomic* m_sceneSnapshotPtr{nullptr}; + std::atomic 
m_vorticity{0.2f}; + std::atomic m_avgComputeTimeMs{0.0f}; float m_sumComputeTimeMs{0.0f}; int m_computeTimeSamples{0}; diff --git a/engine/src/Config.cpp b/engine/src/Config.cpp index 904f81a..a170aaf 100644 --- a/engine/src/Config.cpp +++ b/engine/src/Config.cpp @@ -48,6 +48,22 @@ template <> struct convert { config.inputChannels = node["input_channels"].as(); } + if (node["vorticity_default"]) + { + config.vorticityDefault = node["vorticity_default"].as(); + } + if (node["vorticity_min"]) + { + config.vorticityMin = node["vorticity_min"].as(); + } + if (node["vorticity_max"]) + { + config.vorticityMax = node["vorticity_max"].as(); + } + if (node["vorticity_step"]) + { + config.vorticityStep = node["vorticity_step"].as(); + } return true; } }; diff --git a/engine/src/FluidScene.cpp b/engine/src/FluidScene.cpp index 382f0a1..167a93c 100644 --- a/engine/src/FluidScene.cpp +++ b/engine/src/FluidScene.cpp @@ -4,6 +4,7 @@ #include "Profiling.hpp" #include "SceneState.hpp" #include +#include #include #include @@ -43,6 +44,10 @@ void FluidScene::onInit() m_renderer->initialize(); m_simulationFPS = config.getSimulationFPS(); + m_vorticity = config.getVorticityDefault(); + m_vorticityMin = config.getVorticityMin(); + m_vorticityMax = config.getVorticityMax(); + m_vorticityStep = config.getVorticityStep(); std::cout << "FluidScene initialized" << std::endl; } @@ -128,6 +133,15 @@ void FluidScene::onRenderUI() ImGui::SliderInt("Velocity Brush", &m_velocityBrushSize, 1, 15); ImGui::SliderFloat("Velocity Strength", &m_velocityStrength, 0.1f, 1.0f, "%.1f"); ImGui::SliderFloat("Velocity Decay %", &m_velocityDecayPercent, 1.0f, 10.0f, "%.1f%%"); + + ImGui::Separator(); + if (ImGui::SliderFloat("Vorticity", &m_vorticity, m_vorticityMin, m_vorticityMax, "%.2f")) + { + m_vorticity = std::round(m_vorticity / m_vorticityStep) * m_vorticityStep; + if (m_simulation) + m_simulation->setVorticity(m_vorticity); + } + ImGui::Checkbox("Debug Overlay (O)", &m_showDebugOverlay); if 
(ImGui::IsItemEdited() && m_renderer) { diff --git a/engine/src/Simulation.cpp b/engine/src/Simulation.cpp index d7a9dc0..93cf32e 100644 --- a/engine/src/Simulation.cpp +++ b/engine/src/Simulation.cpp @@ -15,6 +15,7 @@ Simulation::Simulation() const auto& config = Config::getInstance(); m_targetStepTime = 1.0f / config.getSimulationFPS(); m_useGpu = config.isGpuEnabled(); + m_vorticity.store(config.getVorticityDefault(), std::memory_order_relaxed); int resolution = config.getGridResolution(); @@ -156,6 +157,16 @@ void Simulation::setSceneSnapshot(const std::atomic* snapsho m_sceneSnapshotPtr = snapshot; } +void Simulation::setVorticity(float v) +{ + m_vorticity.store(v, std::memory_order_release); +} + +float Simulation::getVorticity() const +{ + return m_vorticity.load(std::memory_order_acquire); +} + void Simulation::workerLoop_() { PROFILE_SET_THREAD_NAME("Simulation Thread"); @@ -292,20 +303,40 @@ float Simulation::runInferenceStep_(SimulationBuffer* frontBuf, SimulationBuffer } Ort::AllocatorWithDefaultOptions allocator; - auto inputName = m_ortSession->GetInputNameAllocated(0, allocator); + auto inputName0 = m_ortSession->GetInputNameAllocated(0, allocator); auto outputName = m_ortSession->GetOutputNameAllocated(0, allocator); - - const char* inputNames[] = {inputName.get()}; const char* outputNames[] = {outputName.get()}; + const size_t numModelInputs = m_ortSession->GetInputCount(); + auto inferenceStart = Clock::now(); std::vector outputTensors; { PROFILE_SCOPE_NAMED("ONNX Runtime Execute"); PROFILE_ZONE_TEXT("Model Inference", 15); - outputTensors = m_ortSession->Run(Ort::RunOptions{nullptr}, inputNames, &inputTensor, 1, - outputNames, 1); + if (numModelInputs >= 2) + { + auto inputName1 = m_ortSession->GetInputNameAllocated(1, allocator); + float vorticity = m_vorticity.load(std::memory_order_acquire); + const int64_t condShape[] = {1}; + auto condMemInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + Ort::Value condTensor = 
Ort::Value::CreateTensor( + condMemInfo, &vorticity, 1, condShape, 1); + + const char* inputNames[] = {inputName0.get(), inputName1.get()}; + std::vector inputVec; + inputVec.push_back(std::move(inputTensor)); + inputVec.push_back(std::move(condTensor)); + outputTensors = m_ortSession->Run(Ort::RunOptions{nullptr}, inputNames, + inputVec.data(), 2, outputNames, 1); + } + else + { + const char* inputNames[] = {inputName0.get()}; + outputTensors = m_ortSession->Run(Ort::RunOptions{nullptr}, inputNames, + &inputTensor, 1, outputNames, 1); + } } auto inferenceEnd = Clock::now(); From 5c33acf9810424f6a6d36ed86cdd90ee3d13d536 Mon Sep 17 00:00:00 2001 From: ohmatheus Date: Thu, 19 Feb 2026 14:22:58 +0100 Subject: [PATCH 10/10] cpp linter --- engine/src/Simulation.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/engine/src/Simulation.cpp b/engine/src/Simulation.cpp index 93cf32e..0b00bea 100644 --- a/engine/src/Simulation.cpp +++ b/engine/src/Simulation.cpp @@ -321,8 +321,8 @@ float Simulation::runInferenceStep_(SimulationBuffer* frontBuf, SimulationBuffer float vorticity = m_vorticity.load(std::memory_order_acquire); const int64_t condShape[] = {1}; auto condMemInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); - Ort::Value condTensor = Ort::Value::CreateTensor( - condMemInfo, &vorticity, 1, condShape, 1); + Ort::Value condTensor = + Ort::Value::CreateTensor(condMemInfo, &vorticity, 1, condShape, 1); const char* inputNames[] = {inputName0.get(), inputName1.get()}; std::vector inputVec;