From c4721d454ccf5e9f98f5f005968828d8664512b0 Mon Sep 17 00:00:00 2001 From: ohmatheus Date: Tue, 17 Feb 2026 09:54:27 +0100 Subject: [PATCH 1/2] Added a stride to the dataloader so that K-step training does not see the same frame multiple times in each epoch. Also added an epoch-dependent offset (modulo stride) so that each epoch does not see the same 'sequence' of frame sequences. --- ml/config/base_config.yaml | 1 + ml/config/training_config.py | 8 ++++++++ ml/dataset/npz_sequence.py | 38 ++++++++++++++++++++++++------------ ml/scripts/train.py | 5 +++++ ml/training/trainer.py | 2 ++ 5 files changed, 41 insertions(+), 13 deletions(-) diff --git a/ml/config/base_config.yaml b/ml/config/base_config.yaml index d81a7e2..0da0dbd 100644 --- a/ml/config/base_config.yaml +++ b/ml/config/base_config.yaml @@ -51,6 +51,7 @@ physics_loss: # NOTE: rollout_step is FIXED per variant (K1_A=1, K2_A=2, K3_A=3, etc.) # No rollout_schedule needed - each K variant trains at fixed K rollout_step: 0 +rollout_stride: 1 rollout_weight_decay: 1.10 rollout_gradient_truncation: false validation_use_rollout_k: true diff --git a/ml/config/training_config.py b/ml/config/training_config.py index b495c66..a809c96 100644 --- a/ml/config/training_config.py +++ b/ml/config/training_config.py @@ -100,6 +100,7 @@ class TrainingConfig(BaseModel): # Multi-step rollout training rollout_step: int = 0 + rollout_stride: float = 1.0 # fraction of K used as stride (0.0–1.0], 1.0 = stride equals K rollout_weight_decay: float = 1.10 # not used if rollout_final_step_only is True # ∂total_loss/∂θ = # w₀/sum × ∂loss₀/∂θ [direct from step 0] @@ -120,6 +121,13 @@ def validate_rollout_step(cls, v: int) -> int: raise ValueError(f"rollout_step must be >= 0, got {v}") return v + @field_validator("rollout_stride") + @classmethod + def validate_rollout_stride(cls, v: float) -> float: + if v <= 0.0 or v > 1.0: + raise ValueError(f"rollout_stride must be in (0.0, 1.0], got {v}") + return v + @property def checkpoint_dir_variant(self) -> Path: if self.variant: diff 
--git a/ml/dataset/npz_sequence.py b/ml/dataset/npz_sequence.py index 4dd2775..5179d5c 100644 --- a/ml/dataset/npz_sequence.py +++ b/ml/dataset/npz_sequence.py @@ -52,27 +52,27 @@ def _load_sequence_metadata(path: Path) -> tuple[int, int, int]: return T, H, W -def _build_real_sequence_indices( - seq_paths: list[Path], rollout_steps: int = 1 -) -> tuple[list[tuple[int, int]], list[int], int, int]: - indices: list[tuple[int, int]] = [] +def _load_sequence_dimensions( + seq_paths: list[Path], +) -> tuple[list[int], int, int]: frame_counts: list[int] = [] h, w = 0, 0 - for si, path in enumerate(seq_paths): T, H, W = _load_sequence_metadata(path) - - # Store dimensions from first sequence if si == 0: h, w = H, W - frame_counts.append(T) + return frame_counts, h, w - # ! Ensure t+rollout_steps is valid - for t in range(1, T - rollout_steps): - indices.append((si, t)) - return indices, frame_counts, h, w +def _build_indices_for_offset( + frame_counts: list[int], rollout_steps: int, stride: int, offset: int = 0 +) -> list[tuple[int, int]]: + indices: list[tuple[int, int]] = [] + for si, T in enumerate(frame_counts): + for t in range(1 + offset, T - rollout_steps, stride): + indices.append((si, t)) + return indices def _build_fake_sequence_indices( @@ -231,6 +231,7 @@ def __init__( augmentation_config: dict | None = None, preload: bool = False, rollout_steps: int = 1, + stride: int = 1, ) -> None: self.npz_dir = npz_dir self.normalize = normalize @@ -263,7 +264,14 @@ def __init__( stats_path = PROJECT_ROOT_PATH / project_config.vdb_tools.stats_output_file self._norm_scales = load_normalization_scales(stats_path) - self._index, frame_counts, h, w = _build_real_sequence_indices(self.seq_paths, rollout_steps=rollout_steps) + frame_counts, h, w = _load_sequence_dimensions(self.seq_paths) + + self._indices_by_offset: list[list[tuple[int, int]]] = [] + for offset in range(stride): + indices = _build_indices_for_offset(frame_counts, rollout_steps, stride, offset) + 
self._indices_by_offset.append(indices) + + self._index = self._indices_by_offset[0] if not self._index: raise RuntimeError("No valid samples found (need T>=3 per sequence)") @@ -274,6 +282,10 @@ def __init__( if self.preload: self._preload_sequences() + def set_epoch(self, epoch: int) -> None: + offset = epoch % len(self._indices_by_offset) + self._index = self._indices_by_offset[offset] + def _estimate_memory_usage(self, seq_paths: list[Path]) -> tuple[int, str]: total_bytes = 0 diff --git a/ml/scripts/train.py b/ml/scripts/train.py index ef13395..719f10b 100644 --- a/ml/scripts/train.py +++ b/ml/scripts/train.py @@ -84,6 +84,8 @@ def train_single_variant( "flip_probability": config.augmentation.flip_probability, } + train_stride = max(1, round(config.rollout_stride * config.rollout_step)) + train_ds = FluidNPZSequenceDataset( npz_dir=npz_dir, normalize=config.normalize, @@ -92,6 +94,7 @@ def train_single_variant( augmentation_config=aug_config_dict, preload=config.preload_dataset, rollout_steps=config.rollout_step, + stride=train_stride, ) val_rollout_steps = config.rollout_step if config.validation_use_rollout_k else 1 @@ -207,6 +210,8 @@ def train_single_variant( "augmentation.flip_axis": config.augmentation.flip_axis, # Multi-step rollout training "rollout_step": config.rollout_step, + "rollout_stride": config.rollout_stride, + "train_stride": train_stride, "rollout_weight_decay": config.rollout_weight_decay, "rollout_gradient_truncation": config.rollout_gradient_truncation, "validation_use_rollout_k": config.validation_use_rollout_k, diff --git a/ml/training/trainer.py b/ml/training/trainer.py index 863d93b..319d24d 100644 --- a/ml/training/trainer.py +++ b/ml/training/trainer.py @@ -627,6 +627,8 @@ def train(self) -> None: for epoch in range(self.config.epochs): epoch_start = time.time() + self.train_loader.dataset.set_epoch(epoch) + train_losses = self.train_epoch() val_losses = self.validate() From 4ae9eddfe79553e2dc4539287ed0b6baa3a8ee85 Mon Sep 17 
00:00:00 2001 From: ohmatheus Date: Tue, 17 Feb 2026 09:59:11 +0100 Subject: [PATCH 2/2] fixed type check --- ml/training/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml/training/trainer.py b/ml/training/trainer.py index 319d24d..9f1a4d5 100644 --- a/ml/training/trainer.py +++ b/ml/training/trainer.py @@ -627,7 +627,7 @@ def train(self) -> None: for epoch in range(self.config.epochs): epoch_start = time.time() - self.train_loader.dataset.set_epoch(epoch) + cast("FluidNPZSequenceDataset", self.train_loader.dataset).set_epoch(epoch) train_losses = self.train_epoch() val_losses = self.validate()