Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ml/config/base_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ physics_loss:
# NOTE: rollout_step is FIXED per variant (K1_A=1, K2_A=2, K3_A=3, etc.)
# No rollout_schedule needed - each K variant trains at fixed K
rollout_step: 0
rollout_stride: 1
rollout_weight_decay: 1.10
rollout_gradient_truncation: false
validation_use_rollout_k: true
Expand Down
8 changes: 8 additions & 0 deletions ml/config/training_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class TrainingConfig(BaseModel):

# Multi-step rollout training
rollout_step: int = 0
rollout_stride: float = 1.0 # fraction of K used as stride (0.0–1.0], 1.0 = stride equals K
rollout_weight_decay: float = 1.10 # not used if rollout_final_step_only is True
# ∂total_loss/∂θ =
# w₀/sum × ∂loss₀/∂θ [direct from step 0]
Expand All @@ -120,6 +121,13 @@ def validate_rollout_step(cls, v: int) -> int:
raise ValueError(f"rollout_step must be >= 0, got {v}")
return v

@field_validator("rollout_stride")
@classmethod
def validate_rollout_stride(cls, v: float) -> float:
if v <= 0.0 or v > 1.0:
raise ValueError(f"rollout_stride must be in (0.0, 1.0], got {v}")
return v

@property
def checkpoint_dir_variant(self) -> Path:
if self.variant:
Expand Down
38 changes: 25 additions & 13 deletions ml/dataset/npz_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,27 +52,27 @@ def _load_sequence_metadata(path: Path) -> tuple[int, int, int]:
return T, H, W


def _build_real_sequence_indices(
seq_paths: list[Path], rollout_steps: int = 1
) -> tuple[list[tuple[int, int]], list[int], int, int]:
indices: list[tuple[int, int]] = []
def _load_sequence_dimensions(
    seq_paths: list[Path],
) -> tuple[list[int], int, int]:
    """Read (T, H, W) metadata for every sequence file.

    Returns the per-sequence frame counts plus the spatial dimensions (H, W)
    taken from the first sequence; (0, 0) when seq_paths is empty.
    Spatial dims are assumed uniform across sequences — only the first is kept.
    """
    metadata = [_load_sequence_metadata(path) for path in seq_paths]
    frame_counts = [t for t, _, _ in metadata]
    if metadata:
        _, h, w = metadata[0]
    else:
        h, w = 0, 0
    return frame_counts, h, w

# ! Ensure t+rollout_steps is valid
for t in range(1, T - rollout_steps):
indices.append((si, t))

return indices, frame_counts, h, w
def _build_indices_for_offset(
frame_counts: list[int], rollout_steps: int, stride: int, offset: int = 0
) -> list[tuple[int, int]]:
indices: list[tuple[int, int]] = []
for si, T in enumerate(frame_counts):
for t in range(1 + offset, T - rollout_steps, stride):
indices.append((si, t))
return indices


def _build_fake_sequence_indices(
Expand Down Expand Up @@ -231,6 +231,7 @@ def __init__(
augmentation_config: dict | None = None,
preload: bool = False,
rollout_steps: int = 1,
stride: int = 1,
) -> None:
self.npz_dir = npz_dir
self.normalize = normalize
Expand Down Expand Up @@ -263,7 +264,14 @@ def __init__(
stats_path = PROJECT_ROOT_PATH / project_config.vdb_tools.stats_output_file
self._norm_scales = load_normalization_scales(stats_path)

self._index, frame_counts, h, w = _build_real_sequence_indices(self.seq_paths, rollout_steps=rollout_steps)
frame_counts, h, w = _load_sequence_dimensions(self.seq_paths)

self._indices_by_offset: list[list[tuple[int, int]]] = []
for offset in range(stride):
indices = _build_indices_for_offset(frame_counts, rollout_steps, stride, offset)
self._indices_by_offset.append(indices)

self._index = self._indices_by_offset[0]

if not self._index:
raise RuntimeError("No valid samples found (need T>=3 per sequence)")
Expand All @@ -274,6 +282,10 @@ def __init__(
if self.preload:
self._preload_sequences()

def set_epoch(self, epoch: int) -> None:
    """Select the active index list for this epoch.

    Rotates round-robin through the precomputed per-offset index buckets so
    that successive epochs sample different stride offsets.
    """
    n_offsets = len(self._indices_by_offset)
    self._index = self._indices_by_offset[epoch % n_offsets]

def _estimate_memory_usage(self, seq_paths: list[Path]) -> tuple[int, str]:
total_bytes = 0

Expand Down
5 changes: 5 additions & 0 deletions ml/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ def train_single_variant(
"flip_probability": config.augmentation.flip_probability,
}

train_stride = max(1, round(config.rollout_stride * config.rollout_step))

train_ds = FluidNPZSequenceDataset(
npz_dir=npz_dir,
normalize=config.normalize,
Expand All @@ -92,6 +94,7 @@ def train_single_variant(
augmentation_config=aug_config_dict,
preload=config.preload_dataset,
rollout_steps=config.rollout_step,
stride=train_stride,
)

val_rollout_steps = config.rollout_step if config.validation_use_rollout_k else 1
Expand Down Expand Up @@ -207,6 +210,8 @@ def train_single_variant(
"augmentation.flip_axis": config.augmentation.flip_axis,
# Multi-step rollout training
"rollout_step": config.rollout_step,
"rollout_stride": config.rollout_stride,
"train_stride": train_stride,
"rollout_weight_decay": config.rollout_weight_decay,
"rollout_gradient_truncation": config.rollout_gradient_truncation,
"validation_use_rollout_k": config.validation_use_rollout_k,
Expand Down
2 changes: 2 additions & 0 deletions ml/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,8 @@ def train(self) -> None:
for epoch in range(self.config.epochs):
epoch_start = time.time()

cast("FluidNPZSequenceDataset", self.train_loader.dataset).set_epoch(epoch)

train_losses = self.train_epoch()
val_losses = self.validate()

Expand Down