ml/dataset/npz_sequence.py (19 changes: 8 additions & 11 deletions)
@@ -225,8 +225,8 @@ class FluidNPZSequenceDataset(Dataset):
     def __init__(
         self,
         npz_dir: str | Path,
+        split: str,
         normalize: bool = False,
-        seq_indices: list[int] | None = None,
         is_training: bool = False,
         augmentation_config: dict | None = None,
         preload: bool = False,
@@ -242,18 +242,15 @@ def __init__(
         self.enable_augmentation = self.augmentation_config.get("enable_augmentation", False)
         self.flip_probability = self.augmentation_config.get("flip_probability", 0.5)
 
-        npz_dir_path = Path(npz_dir)
-        all_seq_paths: list[Path] = sorted(
+        npz_dir_path = Path(npz_dir) / split
+        if not npz_dir_path.exists():
+            raise FileNotFoundError(f"Split directory not found: {npz_dir_path}")
+
+        self.seq_paths: list[Path] = sorted(
             [f for f in npz_dir_path.iterdir() if f.name.startswith("seq_") and f.name.endswith(".npz")]
         )
-        if not all_seq_paths:
-            raise FileNotFoundError(f"No seq_*.npz files found in {npz_dir}")
-
-        # Filter to specified sequences if provided (splits)
-        if seq_indices is not None:
-            self.seq_paths = [all_seq_paths[i] for i in seq_indices]
-        else:
-            self.seq_paths = all_seq_paths
+        if not self.seq_paths:
+            raise FileNotFoundError(f"No seq_*.npz files found in {npz_dir_path}")
 
         self.num_real_sequences = len(self.seq_paths)
         # self.num_fake_sequences = _calculate_fake_count(self.num_real_sequences, fake_empty_pct)
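Review note: the dataset now resolves its files from a per-split subdirectory instead of an externally supplied index list. A minimal usage sketch, assuming the data/npz/<resolution> layout used elsewhere in this diff (train/, val/ and test/ subfolders of seq_*.npz), the import path implied by the repository layout, and defaults for the remaining constructor arguments:

    from pathlib import Path

    from dataset.npz_sequence import FluidNPZSequenceDataset

    npz_dir = Path("data/npz/128")  # expected layout: data/npz/128/{train,val,test}/seq_*.npz

    train_ds = FluidNPZSequenceDataset(
        npz_dir=npz_dir,
        split="train",   # reads data/npz/128/train/seq_*.npz, raises FileNotFoundError if missing
        normalize=True,
        is_training=True,
    )
    val_ds = FluidNPZSequenceDataset(npz_dir=npz_dir, split="val")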
ml/scripts/eval_rollout.py (21 changes: 6 additions & 15 deletions)
@@ -8,7 +8,6 @@
 
 sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
 
-import yaml
 
 from config.config import PROJECT_ROOT_PATH
 from config.training_config import TrainingConfig, project_config
@@ -20,7 +19,6 @@
     _plot_rollout_degradation,
     _run_single_rollout,
 )
-from utils.data_splits import make_splits, set_seed
 
 ML_ROOT = Path(__file__).resolve().parent.parent
 DATA_DIR = ML_ROOT.parent / "data" / "npz" / "128"
@@ -65,12 +63,14 @@ def load_from_checkpoint(checkpoint_dir: Path, device: str) -> tuple[TrainingCon
 def run_rollout_to_disk(
     model: torch.nn.Module,
     config: TrainingConfig,
-    test_indices: list[int],
     device: str,
     output_name: str,
 ) -> None:
-    all_seq_paths = sorted([p for p in DATA_DIR.iterdir() if p.name.startswith("seq_") and p.name.endswith(".npz")])
-    test_paths = [all_seq_paths[i] for i in test_indices]
+    test_data_dir = DATA_DIR / "test"
+    if not test_data_dir.exists():
+        raise FileNotFoundError(f"Test split directory not found: {test_data_dir}")
+
+    test_paths = sorted([p for p in test_data_dir.iterdir() if p.name.startswith("seq_") and p.name.endswith(".npz")])
 
     norm_scales = None
     if config.normalize:
@@ -144,17 +144,8 @@ def main() -> None:
 
     config, model = load_from_checkpoint(checkpoint_dir, device)
 
-    with open(BASE_CONFIG_PATH) as f:
-        base_cfg = yaml.safe_load(f)
-    split_ratios = tuple(base_cfg["split_ratios"])
-    split_seed = base_cfg["split_seed"]
-
-    set_seed(split_seed)
-    _, _, test_idx = make_splits(DATA_DIR, split_ratios, split_seed)
-    print(f"Test sequences: {len(test_idx)}")
-
     output_name = args.output_name or checkpoint_dir.name
-    run_rollout_to_disk(model, config, test_idx, device, output_name)
+    run_rollout_to_disk(model, config, device, output_name)
 
 
 if __name__ == "__main__":
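Review note: eval_rollout.py now discovers test sequences purely from the directory structure rather than from recomputed split indices. The on-disk layout it expects, inferred from DATA_DIR and the new test/ check (the individual file names below are illustrative):

    data/
      npz/
        128/
          train/  seq_0000.npz, seq_0001.npz, ...
          val/    seq_0040.npz, ...
          test/   seq_0045.npz, ...   <- read by run_rollout_to_disk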
ml/scripts/train.py (34 changes: 13 additions & 21 deletions)
@@ -15,7 +15,7 @@
 from scripts.variant_manager import VariantManager
 from training.test_evaluation import run_rollout_evaluation, run_test_evaluation
 from training.trainer import Trainer
-from utils.data_splits import make_splits, set_seed
+from utils.seed import set_seed
 
 
 def dict_to_training_config(config_dict: dict) -> TrainingConfig:
@@ -75,9 +75,6 @@ def train_single_variant(
     set_seed(config.split_seed)
 
     npz_dir: Path = config.npz_dir / str(project_config.simulation.grid_resolution)
-    train_idx, val_idx, test_idx = make_splits(npz_dir, config.split_ratios, config.split_seed)
-
-    print(f"Dataset splits: train={len(train_idx)}, val={len(val_idx)}, test={len(test_idx)}")
 
     aug_config_dict = {
         "enable_augmentation": config.augmentation.enable_augmentation,
@@ -88,8 +85,8 @@
 
     train_ds = FluidNPZSequenceDataset(
         npz_dir=npz_dir,
+        split="train",
         normalize=config.normalize,
-        seq_indices=train_idx,
         is_training=True,
         augmentation_config=aug_config_dict,
         preload=config.preload_dataset,
@@ -100,8 +97,8 @@
     val_rollout_steps = config.rollout_step if config.validation_use_rollout_k else 1
     val_ds = FluidNPZSequenceDataset(
         npz_dir=npz_dir,
+        split="val",
         normalize=config.normalize,
-        seq_indices=val_idx,
         is_training=False,
         augmentation_config=None,
         preload=config.preload_dataset,
@@ -243,25 +240,20 @@ def train_single_variant(
         val_loader=val_loader,
         config=config,
         device=config.device,
-        train_indices=train_idx,
-        val_indices=val_idx,
     )
 
     trainer.train()
 
-    if test_idx:
-        run_test_evaluation(
-            model=model,
-            config=config,
-            test_indices=test_idx,
-            device=config.device,
-        )
-        run_rollout_evaluation(
-            model=model,
-            config=config,
-            test_indices=test_idx,
-            device=config.device,
-        )
+    run_test_evaluation(
+        model=model,
+        config=config,
+        device=config.device,
+    )
+    run_rollout_evaluation(
+        model=model,
+        config=config,
+        device=config.device,
+    )
 
     print(f"\nTraining complete: {config.variant.full_model_name}")
     print(f"Best model: {config.checkpoint_dir_variant / 'best_model.pth'}")
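Review note: with make_splits deleted, training assumes the train/, val/ and test/ subdirectories already exist. If they still have to be materialized from a flat folder of seq_*.npz files, a one-off helper along these lines would do it; this is a sketch that is not part of the PR, and the 80/10/10 ratios and move-rather-than-copy behaviour are assumptions:

    import random
    import shutil
    from pathlib import Path


    def materialize_splits(npz_dir: Path, ratios=(0.8, 0.1, 0.1), seed: int = 0) -> None:
        """Shuffle flat seq_*.npz files into train/, val/ and test/ subdirectories."""
        paths = sorted(p for p in npz_dir.iterdir() if p.name.startswith("seq_") and p.name.endswith(".npz"))
        random.Random(seed).shuffle(paths)
        n_train = int(len(paths) * ratios[0])
        n_val = int(len(paths) * ratios[1])
        splits = {
            "train": paths[:n_train],
            "val": paths[n_train:n_train + n_val],
            "test": paths[n_train + n_val:],
        }
        for split, files in splits.items():
            split_dir = npz_dir / split
            split_dir.mkdir(exist_ok=True)
            for f in files:
                shutil.move(str(f), split_dir / f.name)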
ml/training/test_evaluation.py (21 changes: 8 additions & 13 deletions)
@@ -78,24 +78,19 @@ def _compute_batch_metrics(
 def run_test_evaluation(
     model: nn.Module,
     config: TrainingConfig,
-    test_indices: list[int],
     device: str,
 ) -> dict[str, dict[str, float]]:
-    if not test_indices:
-        print("No test indices provided, skipping test evaluation.")
-        return {}
-
     npz_dir = config.npz_dir / str(project_config.simulation.grid_resolution)
 
     print(f"\n{'=' * 70}")
     print("TEST SET EVALUATION (best model + persistence baseline)")
     print(f"{'=' * 70}")
-    print(f"Loading test dataset ({len(test_indices)} sequences)...")
+    print("Loading test dataset...")
 
     test_ds = FluidNPZSequenceDataset(
         npz_dir=npz_dir,
+        split="test",
         normalize=config.normalize,
-        seq_indices=test_indices,
         is_training=False,
         augmentation_config=None,
         preload=config.preload_dataset,
@@ -369,16 +364,16 @@ def log_artifact_flat(fig: Figure, filename: str, dpi: int = 72, artifact_path:
 def run_rollout_evaluation(
     model: nn.Module,
     config: TrainingConfig,
-    test_indices: list[int],
     device: str,
 ) -> dict[str, list[dict[str, float]]]:
-    if not test_indices:
-        print("No test indices provided, skipping rollout evaluation.")
+    npz_dir = config.npz_dir / str(project_config.simulation.grid_resolution)
+    test_npz_dir = npz_dir / "test"
+
+    if not test_npz_dir.exists():
+        print(f"Test split directory not found: {test_npz_dir}, skipping rollout evaluation.")
         return {}
 
-    npz_dir = config.npz_dir / str(project_config.simulation.grid_resolution)
-    all_seq_paths = sorted([p for p in npz_dir.iterdir() if p.name.startswith("seq_") and p.name.endswith(".npz")])
-    test_paths = [all_seq_paths[i] for i in test_indices]
+    test_paths = sorted([p for p in test_npz_dir.iterdir() if p.name.startswith("seq_") and p.name.endswith(".npz")])
 
     norm_scales = None
     if config.normalize:
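Review note: with test_indices removed from both signatures, the evaluation entry points need only a model and a config; a minimal call sketch (assumes a trained model and a fully populated TrainingConfig):

    metrics = run_test_evaluation(model=model, config=config, device=config.device)
    curves = run_rollout_evaluation(model=model, config=config, device=config.device)
    # run_rollout_evaluation returns {} and prints a notice when <npz_dir>/test is absent;
    # run_test_evaluation relies on the dataset raising FileNotFoundError in that case.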
ml/training/trainer.py (17 changes: 2 additions & 15 deletions)
@@ -39,19 +39,13 @@ def __init__(
         val_loader: DataLoader,
         config: TrainingConfig,
         device: str,
-        train_indices: list[int] | None = None,
-        val_indices: list[int] | None = None,
     ) -> None:
         self.model = model
         self.train_loader = train_loader
         self.val_loader = val_loader
         self.config = config.model_copy()
         self.device = device
 
-        # Store indices for dataloader recreation during scheduled rollout
-        self.train_indices = train_indices
-        self.val_indices = val_indices
-
         self.gradient_clip_norm = config.gradient_clip_norm
         self.gradient_clip_enabled = config.gradient_clip_enabled
 
@@ -391,14 +385,7 @@ def _compute_rollout_loss(
         return total_loss, averaged_dict, final_pred
 
     def _create_dataloader_with_rollout_steps(self, K: int, is_training: bool) -> DataLoader:
-        if is_training:
-            if self.train_indices is None:
-                raise RuntimeError("Cannot recreate dataloader: train_indices not provided to Trainer")
-            indices = self.train_indices
-        else:
-            if self.val_indices is None:
-                raise RuntimeError("Cannot recreate dataloader: val_indices not provided to Trainer")
-            indices = self.val_indices
+        split = "train" if is_training else "val"
 
         npz_dir = (
             PROJECT_ROOT_PATH
@@ -408,8 +395,8 @@ def _create_dataloader_with_rollout_steps(self, K: int, is_training: bool) -> Da
         )
 
         dataset = FluidNPZSequenceDataset(
             npz_dir=npz_dir,
+            split=split,
             normalize=self.config.normalize,
-            seq_indices=indices,
             is_training=is_training,
             augmentation_config=self.config.augmentation.model_dump() if is_training else None,
             preload=self.config.preload_dataset,
ml/utils/data_splits.py (54 changes: 0 additions & 54 deletions)

This file was deleted.

ml/utils/seed.py (14 changes: 14 additions & 0 deletions)
@@ -0,0 +1,14 @@
+import random
+
+import numpy as np
+import torch
+
+
+def set_seed(seed: int) -> None:
+    """Set random seeds for reproducibility across Python, NumPy, and PyTorch."""
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
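Review note: a minimal usage sketch for the new helper. The import path mirrors the module location; the commented cuDNN flags are an optional extra for stricter determinism and are not set by this PR:

    from utils.seed import set_seed

    set_seed(42)  # seeds random, numpy and torch (plus all CUDA devices when available)

    # Optional, stricter (and slower) determinism, not part of set_seed():
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False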