From 665fb2b0c68f9e2ca4ddb62bcf4dccb10e9fae58 Mon Sep 17 00:00:00 2001 From: Swaroop Sridhar Date: Wed, 19 Nov 2025 00:48:59 -0500 Subject: [PATCH 1/2] Reduce augmentation probabilities to fix underfitting --- XPointMLTest.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/XPointMLTest.py b/XPointMLTest.py index 3027485..d56e05e 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -381,8 +381,8 @@ def _apply_augmentation(self, all_data, mask): return all_data, mask # 1. Random rotation (0, 90, 180, 270 degrees) - # 75% chance to apply rotation - if self.rng.random() < 0.75: + # 50% chance to apply rotation + if self.rng.random() < 0.50: k = self.rng.integers(1, 4) # 1, 2, or 3 (90°, 180°, 270°) all_data = torch.rot90(all_data, k=k, dims=(-2, -1)) mask = torch.rot90(mask, k=k, dims=(-2, -1)) @@ -397,9 +397,9 @@ def _apply_augmentation(self, all_data, mask): all_data = torch.flip(all_data, dims=(-2,)) mask = torch.flip(mask, dims=(-2,)) - # 4. Add Gaussian noise (30% chance) + # 4. Add Gaussian noise (10% chance) # Small noise helps prevent overfitting to exact pixel values - if self.rng.random() < 0.3: + if self.rng.random() < 0.1: noise_std = self.rng.uniform(0.005, 0.02) noise = torch.randn_like(all_data) * noise_std all_data = all_data + noise @@ -413,9 +413,9 @@ def _apply_augmentation(self, all_data, mask): mean = all_data[c].mean() all_data[c] = contrast * (all_data[c] - mean) + mean + brightness - # 6. Cutout/Random erasing (20% chance) + # 6. Cutout/Random erasing (5% chance) # Prevents model from relying too heavily on specific spatial features - if self.rng.random() < 0.2: + if self.rng.random() < 0.05: h, w = all_data.shape[-2:] cutout_size = int(min(h, w) * self.rng.uniform(0.1, 0.25)) if cutout_size > 0: From 9fbdff2fe0ae2a16cd78137e8e17068909be94dc Mon Sep 17 00:00:00 2001 From: Swaroop Sridhar Date: Tue, 2 Dec 2025 12:44:18 -0500 Subject: [PATCH 2/2] Fix augmentation bug: apply brightness/contrast globally to preserve physical field relationships --- XPointMLTest.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/XPointMLTest.py b/XPointMLTest.py index d56e05e..d50e355 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -404,14 +404,15 @@ def _apply_augmentation(self, all_data, mask): noise = torch.randn_like(all_data) * noise_std all_data = all_data + noise - # 5. Random brightness/contrast adjustment per channel (30% chance) - # Helps model become invariant to intensity variations + # 5. Random brightness/contrast adjustment (30% chance) + # CHANGED: Applied globally across channels to preserve physical relationships + # (e.g., keeping the derivative relationship between psi and B fields) if self.rng.random() < 0.3: - for c in range(all_data.shape[0]): - brightness = self.rng.uniform(-0.1, 0.1) - contrast = self.rng.uniform(0.9, 1.1) - mean = all_data[c].mean() - all_data[c] = contrast * (all_data[c] - mean) + mean + brightness + brightness = self.rng.uniform(-0.1, 0.1) + contrast = self.rng.uniform(0.9, 1.1) + # Apply same transformation to all channels + mean = all_data.mean(dim=(-2, -1), keepdim=True) + all_data = contrast * (all_data - mean) + mean + brightness # 6. Cutout/Random erasing (5% chance) # Prevents model from relying too heavily on specific spatial features