Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion the_well/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "1.2.0"
__version__ = "1.2.1"


__all__ = ["__version__"]
84 changes: 52 additions & 32 deletions the_well/benchmark/trainer/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def __init__(
self.best_val_loss = None
self.starting_val_loss = float("inf")
self.dset_metadata = self.datamodule.train_dataset.metadata
self.dset_norm = None
if self.datamodule.train_dataset.use_normalization:
self.dset_norm = self.datamodule.train_dataset.norm
if formatter == "channels_first_default":
Expand Down Expand Up @@ -176,38 +177,51 @@ def load_checkpoint(self, checkpoint_path: str):
checkpoint["epoch"] + 1
) # Saves after training loop, so start at next epoch

def normalize(self, batch_dict=None, direct_tensor=None):
    """Normalize a batch dict and/or a bare tensor with the dataset's stats.

    Either argument may be None, in which case it is skipped and returned
    as-is. If ``self.dset_norm`` is unset or falsy, both inputs are
    returned unchanged.

    Args:
        batch_dict: Optional batch mapping; its ``"input_fields"`` (and
            ``"constant_fields"``, when present) entries are normalized
            in place.
        direct_tensor: Optional tensor normalized directly. When
            ``self.is_delta`` is set, delta-specific normalization is used
            since deltas have different statistics than raw fields.

    Returns:
        Tuple ``(batch_dict, direct_tensor)`` after normalization.
    """
    if hasattr(self, "dset_norm") and self.dset_norm:
        if batch_dict is not None:
            batch_dict["input_fields"] = self.dset_norm.normalize_flattened(
                batch_dict["input_fields"], "variable"
            )
            # Constant fields use their own statistics key.
            if "constant_fields" in batch_dict:
                batch_dict["constant_fields"] = self.dset_norm.normalize_flattened(
                    batch_dict["constant_fields"], "constant"
                )
        if direct_tensor is not None:
            if self.is_delta:
                direct_tensor = self.dset_norm.normalize_delta_flattened(
                    direct_tensor, "variable"
                )
            else:
                direct_tensor = self.dset_norm.normalize_flattened(
                    direct_tensor, "variable"
                )
    return batch_dict, direct_tensor

def denormalize(self, batch_dict=None, direct_tensor=None):
    """Invert normalization on a batch dict and/or a bare tensor.

    Mirror of :meth:`normalize`. Either argument may be None, in which
    case it is skipped and returned as-is. If ``self.dset_norm`` is unset
    or falsy, both inputs are returned unchanged.

    Args:
        batch_dict: Optional batch mapping; its ``"input_fields"`` (and
            ``"constant_fields"``, when present) entries are denormalized
            in place.
        direct_tensor: Optional tensor denormalized directly.

    Returns:
        Tuple ``(batch_dict, direct_tensor)`` after denormalization.
    """
    if hasattr(self, "dset_norm") and self.dset_norm:
        if batch_dict is not None:
            batch_dict["input_fields"] = self.dset_norm.denormalize_flattened(
                batch_dict["input_fields"], "variable"
            )
            if "constant_fields" in batch_dict:
                batch_dict["constant_fields"] = (
                    self.dset_norm.denormalize_flattened(
                        batch_dict["constant_fields"], "constant"
                    )
                )
        if direct_tensor is not None:
            # Delta denormalization is different than full denormalization
            if self.is_delta:
                direct_tensor = self.dset_norm.delta_denormalize_flattened(
                    direct_tensor, "variable"
                )
            else:
                direct_tensor = self.dset_norm.denormalize_flattened(
                    direct_tensor, "variable"
                )
    return batch_dict, direct_tensor

def rollout_model(self, model, batch, formatter, train=True):
"""Rollout the model for as many steps as we have data for."""
Expand All @@ -216,31 +230,37 @@ def rollout_model(self, model, batch, formatter, train=True):
y_ref.shape[1], self.max_rollout_steps
) # Number of timesteps in target
y_ref = y_ref[:, :rollout_steps]
# NOTE: This is a quick fix so we can make datamodule behavior consistent. Revisit this next release (MM).
if not train:
_, y_ref = self.denormalize(None, y_ref)

# Create a moving batch of one step at a time
moving_batch = batch
moving_batch = dict(batch)
moving_batch["input_fields"] = moving_batch["input_fields"].to(self.device)
if "constant_fields" in moving_batch:
moving_batch["constant_fields"] = moving_batch["constant_fields"].to(
self.device
)
y_preds = []
for i in range(rollout_steps):
if not train:
moving_batch = self.normalize(moving_batch)
# NOTE: This is a quick fix so we can make datamodule behavior consistent.
# Including local normalization schemes means there needs to be the option of normalizing each step
# and there's currently not a registry of local vs global normalization schemes.
if not train and self.datamodule.val_dataset.use_normalization and i > 0:
moving_batch, _ = self.normalize(moving_batch)

inputs, _ = formatter.process_input(moving_batch)
inputs = [x.to(self.device) for x in inputs]
y_pred = model(*inputs)

y_pred = formatter.process_output_channel_last(y_pred)

if not train:
moving_batch, y_pred = self.denormalize(moving_batch, y_pred)

if (not train) and self.is_delta:
assert {
assert (
moving_batch["input_fields"][:, -1, ...].shape == y_pred.shape
}, f"Mismatching shapes between last input timestep {moving_batch[:, -1, ...].shape}\
), f"Mismatching shapes between last input timestep {moving_batch[:, -1, ...].shape}\
and prediction {y_pred.shape}"
y_pred = moving_batch["input_fields"][:, -1, ...] + y_pred
y_pred = formatter.process_output_expand_time(y_pred)
Expand Down
8 changes: 8 additions & 0 deletions the_well/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ def __init__(
well_split_name="valid",
include_filters=include_filters,
exclude_filters=exclude_filters,
use_normalization=use_normalization,
normalization_type=normalization_type,
n_steps_input=n_steps_input,
n_steps_output=n_steps_output,
storage_options=storage_kwargs,
Expand All @@ -181,6 +183,8 @@ def __init__(
well_split_name="valid",
include_filters=include_filters,
exclude_filters=exclude_filters,
use_normalization=use_normalization,
normalization_type=normalization_type,
max_rollout_steps=max_rollout_steps,
n_steps_input=n_steps_input,
n_steps_output=n_steps_output,
Expand All @@ -201,6 +205,8 @@ def __init__(
well_split_name="test",
include_filters=include_filters,
exclude_filters=exclude_filters,
use_normalization=use_normalization,
normalization_type=normalization_type,
n_steps_input=n_steps_input,
n_steps_output=n_steps_output,
storage_options=storage_kwargs,
Expand All @@ -219,6 +225,8 @@ def __init__(
well_split_name="test",
include_filters=include_filters,
exclude_filters=exclude_filters,
use_normalization=use_normalization,
normalization_type=normalization_type,
max_rollout_steps=max_rollout_steps,
n_steps_input=n_steps_input,
n_steps_output=n_steps_output,
Expand Down
Loading