16 changes: 5 additions & 11 deletions fme/ace/data_loading/benchmark.py
@@ -43,22 +43,16 @@ def build(self):
requirements=DataRequirements(self.names, self.n_timesteps),
)

def configure_wandb(self, env_vars: dict | None = None, **kwargs):
config = to_flat_dict(dataclasses.asdict(self))
# our wandb class requires "experiment_dir" to be in config
config["experiment_dir"] = TMPDIR
os.makedirs(TMPDIR)
self.logging.configure_wandb(config=config, env_vars=env_vars, **kwargs)

def configure_logging(self):
self.logging.configure_logging("/tmp", "log.txt")
config = to_flat_dict(dataclasses.asdict(self))
os.makedirs(TMPDIR, exist_ok=True)
self.logging.configure_logging(
TMPDIR, "log.txt", config=config, resumable=False
)


def benchmark(config: BenchmarkConfig):
config.configure_logging()
env_vars = logging_utils.retrieve_env_vars()
beaker_url = logging_utils.log_beaker_url()
config.configure_wandb(env_vars=env_vars, notes=beaker_url)
wandb = WandB.get_instance()

with GlobalTimer():
16 changes: 2 additions & 14 deletions fme/ace/inference/evaluator.py
@@ -8,7 +8,6 @@
import torch

import fme
import fme.core.logging_utils as logging_utils
from fme.ace.aggregator.inference import InferenceEvaluatorAggregatorConfig
from fme.ace.data_loading.batch_data import BatchData
from fme.ace.data_loading.getters import get_inference_data
@@ -162,14 +161,9 @@ def __post_init__(self):
)

def configure_logging(self, log_filename: str):
self.logging.configure_logging(self.experiment_dir, log_filename)

def configure_wandb(
self, env_vars: dict | None = None, resumable: bool = False, **kwargs
):
config = to_flat_dict(dataclasses.asdict(self))
Review comment (Member):
Should we remove to_flat_dict here (and in the other inference entrypoints), as was done in the training scripts?
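A sketch of what that could look like here, mirroring the run_train change in this PR (this assumes configure_logging accepts the nested, unflattened dict, as the train.py call suggests):

    def configure_logging(self, log_filename: str):
        # Pass the nested dict directly rather than flattening it first,
        # matching what fme/ace/train/train.py now does.
        config = dataclasses.asdict(self)
        self.logging.configure_logging(
            self.experiment_dir, log_filename, config=config, resumable=False
        )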

self.logging.configure_wandb(
config=config, env_vars=env_vars, resumable=resumable, **kwargs
self.logging.configure_logging(
self.experiment_dir, log_filename, config=config, resumable=False
)

def load_stepper(self) -> Stepper:
@@ -246,16 +240,10 @@ def run_evaluator_from_config(config: InferenceEvaluatorConfig):

makedirs(config.experiment_dir, exist_ok=True)
config.configure_logging(log_filename="inference_out.log")
env_vars = logging_utils.retrieve_env_vars()
beaker_url = logging_utils.log_beaker_url()
config.configure_wandb(env_vars=env_vars, notes=beaker_url)

if fme.using_gpu():
torch.backends.cudnn.benchmark = True

logging_utils.log_versions()
logging.info(f"Current device is {fme.get_device()}")

stepper_config = config.load_stepper_config()
logging.info("Initializing data loader")
window_requirements = stepper_config.get_evaluation_window_data_requirements(
16 changes: 2 additions & 14 deletions fme/ace/inference/inference.py
@@ -13,7 +13,6 @@
from xarray.coding.times import CFDatetimeCoder

import fme
import fme.core.logging_utils as logging_utils
from fme.ace.aggregator.inference import InferenceAggregatorConfig
from fme.ace.data_loading.batch_data import BatchData, PrognosticState
from fme.ace.data_loading.getters import get_forcing_data
@@ -235,14 +234,9 @@ def __post_init__(self):
)

def configure_logging(self, log_filename: str):
self.logging.configure_logging(self.experiment_dir, log_filename)

def configure_wandb(
self, env_vars: dict | None = None, resumable: bool = False, **kwargs
):
config = to_flat_dict(dataclasses.asdict(self))
self.logging.configure_wandb(
config=config, env_vars=env_vars, resumable=resumable, **kwargs
self.logging.configure_logging(
self.experiment_dir, log_filename, config=config, resumable=False
)

def load_stepper(self) -> Stepper:
@@ -299,16 +293,10 @@ def run_inference_from_config(config: InferenceConfig):
with timer.context("initialization"):
makedirs(config.experiment_dir, exist_ok=True)
config.configure_logging(log_filename="inference_out.log")
env_vars = logging_utils.retrieve_env_vars()
beaker_url = logging_utils.log_beaker_url()
config.configure_wandb(env_vars=env_vars, notes=beaker_url)

if fme.using_gpu():
torch.backends.cudnn.benchmark = True

logging_utils.log_versions()
logging.info(f"Current device is {fme.get_device()}")

stepper_config = config.load_stepper_config()
data_requirements = stepper_config.get_forcing_window_data_requirements(
n_forward_steps=config.forward_steps_in_memory
17 changes: 6 additions & 11 deletions fme/ace/train/train.py
@@ -58,7 +58,6 @@
import xarray as xr

import fme
import fme.core.logging_utils as logging_utils
from fme.ace.aggregator import (
OneStepAggregator,
OneStepAggregatorConfig,
@@ -75,7 +74,6 @@
from fme.core.cli import prepare_config, prepare_directory
from fme.core.dataset_info import DatasetInfo
from fme.core.derived_variables import get_derived_variable_metadata
from fme.core.dicts import to_flat_dict
from fme.core.distributed import Distributed
from fme.core.generics.data import InferenceDataABC
from fme.core.generics.trainer import (
@@ -254,15 +252,12 @@ def run_train(builders: TrainBuilders, config: TrainConfig):
torch.backends.cudnn.benchmark = True
if not os.path.isdir(config.experiment_dir):
os.makedirs(config.experiment_dir, exist_ok=True)
config.logging.configure_logging(config.experiment_dir, log_filename="out.log")
env_vars = logging_utils.retrieve_env_vars()
logging_utils.log_versions()
beaker_url = logging_utils.log_beaker_url()
config_as_dict = to_flat_dict(dataclasses.asdict(config))
config.logging.configure_wandb(
config=config_as_dict,
env_vars=env_vars,
notes=beaker_url,
config_data = dataclasses.asdict(config)
config.logging.configure_logging(
config.experiment_dir,
log_filename="out.log",
config=config_data,
resumable=True,
)
if config.resume_results is not None:
logging.info(
12 changes: 9 additions & 3 deletions fme/core/generics/test_trainer.py
@@ -922,7 +922,9 @@ def _get_trainer(train_losses, val_losses, inference_errors):
return trainer

with mock_wandb() as wandb:
LoggingConfig(log_to_wandb=True).configure_wandb({"experiment_dir": tmp_path})
LoggingConfig(log_to_wandb=True)._configure_wandb(
experiment_dir=tmp_path, config={}, resumable=True
)
# run training in two segments to ensure coverage of check that extra validation
# really only happens before any training is done.
trainer = _get_trainer(train_losses[:1], val_losses[:2], inference_errors[:2])
@@ -1066,7 +1068,9 @@ def _get_trainer(train_losses, val_losses, inference_errors):
return trainer

with mock_wandb() as wandb:
LoggingConfig(log_to_wandb=True).configure_wandb({"experiment_dir": tmp_path})
LoggingConfig(log_to_wandb=True)._configure_wandb(
experiment_dir=tmp_path, config={}, resumable=True
)
trainer = _get_trainer(train_losses, val_losses, inference_errors)
trainer.train()
wandb_logs = wandb.get_logs()
@@ -1105,7 +1109,9 @@ def _get_trainer(train_losses, val_losses, inference_errors):
return trainer

with mock_wandb() as wandb:
LoggingConfig(log_to_wandb=True).configure_wandb({"experiment_dir": tmp_path})
LoggingConfig(log_to_wandb=True)._configure_wandb(
experiment_dir=tmp_path, config={}, resumable=True
)
trainer = _get_trainer(train_losses, val_losses, inference_errors)
trainer.train()
wandb_logs = wandb.get_logs()
60 changes: 49 additions & 11 deletions fme/core/logging_utils.py
@@ -7,6 +7,7 @@
from typing import Any

from fme.core.cloud import is_local
from fme.core.device import get_device
from fme.core.distributed import Distributed
from fme.core.wandb import WandB

@@ -54,7 +55,34 @@ class LoggingConfig:
def __post_init__(self):
self._dist = Distributed.get_instance()

def configure_logging(self, experiment_dir: str, log_filename: str):
def configure_logging(
self,
experiment_dir: str,
log_filename: str,
config: Mapping[str, Any],
resumable: bool = False,
Review comment (Member):
The default value for resumable was changed from True to False. Please double check that resumable=True is explicitly passed everywhere configure_wandb was previously called without the resumable argument.
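For reference, the training entrypoint in this PR now passes it explicitly (fme/ace/train/train.py):

    config_data = dataclasses.asdict(config)
    config.logging.configure_logging(
        config.experiment_dir,
        log_filename="out.log",
        config=config_data,
        resumable=True,
    )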

):
"""
Configure global logging settings, including WandB, and output
initial logs of the runtime environment.

Args:
experiment_dir: Directory to save logs to.
log_filename: Name of the log file.
config: Configuration dictionary to log to WandB.
resumable: Whether this is a resumable run.
"""
self._configure_logging_module(experiment_dir, log_filename)
log_versions()
log_beaker_url()
self._configure_wandb(
experiment_dir=experiment_dir,
config=config,
resumable=resumable,
)
logging.info(f"Current device is {get_device()}")

def _configure_logging_module(self, experiment_dir: str, log_filename: str):
"""
Configure the global `logging` module based on this LoggingConfig.
"""
@@ -85,14 +113,15 @@ def configure_logging(self, experiment_dir: str, log_filename: str):
fh.setFormatter(logging.Formatter(self.log_format))
logger.addHandler(fh)

def configure_wandb(
def _configure_wandb(
self,
experiment_dir: str,
config: Mapping[str, Any],
env_vars: Mapping[str, Any] | None = None,
resumable: bool = True,
resume: Any = None,
**kwargs,
):
env_vars = retrieve_env_vars()
if resume is not None:
raise ValueError(
"The 'resume' argument is no longer supported, "
@@ -107,7 +136,6 @@ def configure_wandb(
elif env_vars is not None:
config_copy["environment"] = env_vars

experiment_dir = config["experiment_dir"]
if self.wandb_dir_in_experiment_dir:
wandb_dir = experiment_dir
else:
@@ -123,10 +151,24 @@
experiment_dir=experiment_dir,
resumable=resumable,
dir=wandb_dir,
**kwargs,
Review comment (Member):
**kwargs is no longer used. Do we still need it?
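If no caller passes extra keyword arguments any more, the signature could presumably be trimmed to something like this (a sketch, not part of this diff):

    def _configure_wandb(
        self,
        experiment_dir: str,
        config: Mapping[str, Any],
        resumable: bool = True,
        resume: Any = None,
    ):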

notes=_get_beaker_url(_get_beaker_id()),
)


def _get_beaker_id() -> str | None:
try:
return os.environ["BEAKER_EXPERIMENT_ID"]
except KeyError:
logging.warning("Beaker Experiment ID not found.")
return None


def _get_beaker_url(beaker_id: str | None) -> str:
if beaker_id is None:
return "No beaker URL."
return f"https://beaker.org/ex/{beaker_id}"


def log_versions():
import torch

@@ -158,13 +200,9 @@ def log_beaker_url(beaker_id=None):
Returns the Beaker URL.
"""
if beaker_id is None:
try:
beaker_id = os.environ["BEAKER_EXPERIMENT_ID"]
except KeyError:
logging.warning("Beaker Experiment ID not found.")
return None
beaker_id = _get_beaker_id()

beaker_url = f"https://beaker.org/ex/{beaker_id}"
beaker_url = _get_beaker_url(beaker_id)
logging.info(f"Beaker ID: {beaker_id}")
logging.info(f"Beaker URL: {beaker_url}")
return beaker_url
16 changes: 2 additions & 14 deletions fme/coupled/inference/evaluator.py
@@ -7,7 +7,6 @@
import torch

import fme
import fme.core.logging_utils as logging_utils
from fme.ace.stepper import load_stepper as load_single_stepper
from fme.ace.stepper import load_stepper_config as load_single_stepper_config
from fme.core.cli import prepare_config, prepare_directory
@@ -189,14 +188,9 @@ class InferenceEvaluatorConfig:
)

def configure_logging(self, log_filename: str):
self.logging.configure_logging(self.experiment_dir, log_filename)

def configure_wandb(
self, env_vars: dict | None = None, resumable: bool = False, **kwargs
):
config = to_flat_dict(dataclasses.asdict(self))
self.logging.configure_wandb(
config=config, env_vars=env_vars, resumable=resumable, **kwargs
self.logging.configure_logging(
self.experiment_dir, log_filename, config=config, resumable=False
)

def load_stepper(self) -> CoupledStepper:
@@ -268,16 +262,10 @@ def run_evaluator_from_config(config: InferenceEvaluatorConfig):

makedirs(config.experiment_dir, exist_ok=True)
config.configure_logging(log_filename="inference_out.log")
env_vars = logging_utils.retrieve_env_vars()
beaker_url = logging_utils.log_beaker_url()
config.configure_wandb(env_vars=env_vars, notes=beaker_url)

if fme.using_gpu():
torch.backends.cudnn.benchmark = True

logging_utils.log_versions()
logging.info(f"Current device is {fme.get_device()}")

stepper_config = config.load_stepper_config()
logging.info("Loading inference data")
window_requirements = stepper_config.get_evaluation_window_data_requirements(
16 changes: 2 additions & 14 deletions fme/coupled/inference/inference.py
@@ -8,7 +8,6 @@
import xarray as xr

import fme
import fme.core.logging_utils as logging_utils
from fme.ace.data_loading.inference import (
ExplicitIndices,
InferenceInitialConditionIndices,
@@ -143,14 +142,9 @@ class InferenceConfig:
n_ensemble_per_ic: int = 1

def configure_logging(self, log_filename: str):
self.logging.configure_logging(self.experiment_dir, log_filename)

def configure_wandb(
self, env_vars: dict | None = None, resumable: bool = False, **kwargs
):
config = to_flat_dict(dataclasses.asdict(self))
self.logging.configure_wandb(
config=config, env_vars=env_vars, resumable=resumable, **kwargs
self.logging.configure_logging(
self.experiment_dir, log_filename, config=config, resumable=False
)

def load_stepper(self) -> CoupledStepper:
@@ -227,16 +221,10 @@ def run_inference_from_config(config: InferenceConfig):

makedirs(config.experiment_dir, exist_ok=True)
config.configure_logging(log_filename="inference_out.log")
env_vars = logging_utils.retrieve_env_vars()
beaker_url = logging_utils.log_beaker_url()
config.configure_wandb(env_vars=env_vars, notes=beaker_url)

if fme.using_gpu():
torch.backends.cudnn.benchmark = True

logging_utils.log_versions()
logging.info(f"Current device is {fme.get_device()}")

stepper_config = config.load_stepper_config()
data_requirements = stepper_config.get_forcing_window_data_requirements(
n_coupled_steps=config.coupled_steps_in_memory