From 69859a1324c0d541b5713f0f33b935c33343df82 Mon Sep 17 00:00:00 2001
From: Jeremy McGibbon
Date: Mon, 2 Feb 2026 16:10:50 +0000
Subject: [PATCH 1/7] reduce duplicated logging code, standardize logs

---
 fme/ace/data_loading/benchmark.py          |  8 ++--
 fme/ace/inference/evaluator.py             | 16 +------
 fme/ace/inference/inference.py             | 16 +------
 fme/ace/train/train.py                     | 17 +++----
 fme/core/logging_utils.py                  | 51 ++++++++++++++++++----
 fme/coupled/inference/evaluator.py         | 16 +------
 fme/coupled/inference/inference.py         | 16 +------
 fme/coupled/train/train.py                 | 15 +++----
 fme/diffusion/train.py                     | 15 +++----
 fme/downscaling/evaluator.py               | 12 +----
 fme/downscaling/inference/inference.py     | 17 +++----
 fme/downscaling/predict.py                 | 12 +----
 fme/downscaling/train.py                   | 12 +----
 scripts/monthly_data/write_monthly_data.py |  8 ++--
 14 files changed, 88 insertions(+), 143 deletions(-)

diff --git a/fme/ace/data_loading/benchmark.py b/fme/ace/data_loading/benchmark.py
index 5e982521c..952f33ec7 100644
--- a/fme/ace/data_loading/benchmark.py
+++ b/fme/ace/data_loading/benchmark.py
@@ -51,14 +51,14 @@ def configure_wandb(self, env_vars: dict | None = None, **kwargs):
         self.logging.configure_wandb(config=config, env_vars=env_vars, **kwargs)
 
     def configure_logging(self):
-        self.logging.configure_logging("/tmp", "log.txt")
+        config = to_flat_dict(dataclasses.asdict(self))
+        self.logging.configure_logging(
+            "/tmp", "log.txt", config=config, resumable=False
+        )
 
 
 def benchmark(config: BenchmarkConfig):
     config.configure_logging()
-    env_vars = logging_utils.retrieve_env_vars()
-    beaker_url = logging_utils.log_beaker_url()
-    config.configure_wandb(env_vars=env_vars, notes=beaker_url)
     wandb = WandB.get_instance()
 
     with GlobalTimer():
diff --git a/fme/ace/inference/evaluator.py b/fme/ace/inference/evaluator.py
index 79c60e239..2c20dc013 100755
--- a/fme/ace/inference/evaluator.py
+++ b/fme/ace/inference/evaluator.py
@@ -8,7 +8,6 @@
 import torch
 
 import fme
-import fme.core.logging_utils as logging_utils
 from fme.ace.aggregator.inference import InferenceEvaluatorAggregatorConfig
 from fme.ace.data_loading.batch_data import BatchData
 from fme.ace.data_loading.getters import get_inference_data
@@ -162,14 +161,9 @@ def __post_init__(self):
             )
 
     def configure_logging(self, log_filename: str):
-        self.logging.configure_logging(self.experiment_dir, log_filename)
-
-    def configure_wandb(
-        self, env_vars: dict | None = None, resumable: bool = False, **kwargs
-    ):
         config = to_flat_dict(dataclasses.asdict(self))
-        self.logging.configure_wandb(
-            config=config, env_vars=env_vars, resumable=resumable, **kwargs
+        self.logging.configure_logging(
+            self.experiment_dir, log_filename, config=config, resumable=False
         )
 
     def load_stepper(self) -> Stepper:
@@ -246,16 +240,10 @@ def run_evaluator_from_config(config: InferenceEvaluatorConfig):
 
     makedirs(config.experiment_dir, exist_ok=True)
     config.configure_logging(log_filename="inference_out.log")
-    env_vars = logging_utils.retrieve_env_vars()
-    beaker_url = logging_utils.log_beaker_url()
-    config.configure_wandb(env_vars=env_vars, notes=beaker_url)
 
     if fme.using_gpu():
         torch.backends.cudnn.benchmark = True
 
-    logging_utils.log_versions()
-    logging.info(f"Current device is {fme.get_device()}")
-
     stepper_config = config.load_stepper_config()
     logging.info("Initializing data loader")
     window_requirements = stepper_config.get_evaluation_window_data_requirements(
diff --git a/fme/ace/inference/inference.py b/fme/ace/inference/inference.py
index c92ef29c4..534c59d4c 100644
--- a/fme/ace/inference/inference.py
+++ b/fme/ace/inference/inference.py
@@ -13,7 +13,6 @@
 from xarray.coding.times import CFDatetimeCoder
 
 import fme
-import fme.core.logging_utils as logging_utils
 from fme.ace.aggregator.inference import InferenceAggregatorConfig
 from fme.ace.data_loading.batch_data import BatchData, PrognosticState
 from fme.ace.data_loading.getters import get_forcing_data
@@ -235,14 +234,9 @@ def __post_init__(self):
             )
 
     def configure_logging(self, log_filename: str):
-        self.logging.configure_logging(self.experiment_dir, log_filename)
-
-    def configure_wandb(
-        self, env_vars: dict | None = None, resumable: bool = False, **kwargs
-    ):
         config = to_flat_dict(dataclasses.asdict(self))
-        self.logging.configure_wandb(
-            config=config, env_vars=env_vars, resumable=resumable, **kwargs
+        self.logging.configure_logging(
+            self.experiment_dir, log_filename, config=config, resumable=False
         )
 
     def load_stepper(self) -> Stepper:
@@ -299,16 +293,10 @@ def run_inference_from_config(config: InferenceConfig):
     with timer.context("initialization"):
         makedirs(config.experiment_dir, exist_ok=True)
         config.configure_logging(log_filename="inference_out.log")
-        env_vars = logging_utils.retrieve_env_vars()
-        beaker_url = logging_utils.log_beaker_url()
-        config.configure_wandb(env_vars=env_vars, notes=beaker_url)
 
         if fme.using_gpu():
             torch.backends.cudnn.benchmark = True
 
-        logging_utils.log_versions()
-        logging.info(f"Current device is {fme.get_device()}")
-
         stepper_config = config.load_stepper_config()
         data_requirements = stepper_config.get_forcing_window_data_requirements(
             n_forward_steps=config.forward_steps_in_memory
diff --git a/fme/ace/train/train.py b/fme/ace/train/train.py
index a04272e1b..b071d141b 100644
--- a/fme/ace/train/train.py
+++ b/fme/ace/train/train.py
@@ -58,7 +58,6 @@
 import xarray as xr
 
 import fme
-import fme.core.logging_utils as logging_utils
 from fme.ace.aggregator import (
     OneStepAggregator,
     OneStepAggregatorConfig,
@@ -75,7 +74,6 @@
 from fme.core.cli import prepare_config, prepare_directory
 from fme.core.dataset_info import DatasetInfo
 from fme.core.derived_variables import get_derived_variable_metadata
-from fme.core.dicts import to_flat_dict
 from fme.core.distributed import Distributed
 from fme.core.generics.data import InferenceDataABC
 from fme.core.generics.trainer import (
@@ -254,15 +252,12 @@ def run_train(builders: TrainBuilders, config: TrainConfig):
         torch.backends.cudnn.benchmark = True
     if not os.path.isdir(config.experiment_dir):
         os.makedirs(config.experiment_dir, exist_ok=True)
-    config.logging.configure_logging(config.experiment_dir, log_filename="out.log")
-    env_vars = logging_utils.retrieve_env_vars()
-    logging_utils.log_versions()
-    beaker_url = logging_utils.log_beaker_url()
-    config_as_dict = to_flat_dict(dataclasses.asdict(config))
-    config.logging.configure_wandb(
-        config=config_as_dict,
-        env_vars=env_vars,
-        notes=beaker_url,
+    config_data = dataclasses.asdict(config)
+    config.logging.configure_logging(
+        config.experiment_dir,
+        log_filename="out.log",
+        config=config_data,
+        resumable=True,
     )
     if config.resume_results is not None:
         logging.info(
diff --git a/fme/core/logging_utils.py b/fme/core/logging_utils.py
index 30a28dfd1..a282f731a 100644
--- a/fme/core/logging_utils.py
+++ b/fme/core/logging_utils.py
@@ -6,6 +6,8 @@
 from collections.abc import Mapping
 from typing import Any
 
+from torch import get_device
+
 from fme.core.cloud import is_local
 from fme.core.distributed import Distributed
 from fme.core.wandb import WandB
@@ -54,7 +56,28 @@ class LoggingConfig:
     def __post_init__(self):
         self._dist = Distributed.get_instance()
 
-    def configure_logging(self, experiment_dir: str, log_filename: str):
+    def configure_logging(
+        self,
+        experiment_dir: str,
+        log_filename: str,
+        config: Mapping[str, Any],
+        resumable: bool = False,
+    ):
+        """
+        Configure global logging settings, including WandB, and output
+        initial logs of the runtime environment.
+        """
+        self._configure_logging_module(experiment_dir, log_filename)
+        log_versions()
+        log_beaker_url()
+        self.configure_wandb(
+            config=config,
+            resumable=resumable,
+        )
+        log_versions()
+        logging.info(f"Current device is {get_device()}")
+
+    def _configure_logging_module(self, experiment_dir: str, log_filename: str):
         """
         Configure the global `logging` module based on this LoggingConfig.
         """
@@ -88,11 +111,11 @@ def configure_logging(self, experiment_dir: str, log_filename: str):
     def configure_wandb(
         self,
         config: Mapping[str, Any],
-        env_vars: Mapping[str, Any] | None = None,
         resumable: bool = True,
         resume: Any = None,
         **kwargs,
     ):
+        env_vars = retrieve_env_vars()
         if resume is not None:
             raise ValueError(
                 "The 'resume' argument is no longer supported, "
@@ -123,10 +146,24 @@ def configure_wandb(
             experiment_dir=experiment_dir,
             resumable=resumable,
             dir=wandb_dir,
-            **kwargs,
+            notes=_get_beaker_url(_get_beaker_id()),
         )
 
 
+def _get_beaker_id() -> str | None:
+    try:
+        return os.environ["BEAKER_EXPERIMENT_ID"]
+    except KeyError:
+        logging.warning("Beaker Experiment ID not found.")
+        return None
+
+
+def _get_beaker_url(beaker_id: str | None) -> str:
+    if beaker_id is None:
+        return "No beaker URL."
+    return f"https://beaker.org/ex/{beaker_id}"
+
+
 def log_versions():
     import torch
 
@@ -158,13 +195,9 @@ def log_beaker_url(beaker_id=None):
     Returns the Beaker URL.
     """
     if beaker_id is None:
-        try:
-            beaker_id = os.environ["BEAKER_EXPERIMENT_ID"]
-        except KeyError:
-            logging.warning("Beaker Experiment ID not found.")
-            return None
+        beaker_id = _get_beaker_id()
 
-    beaker_url = f"https://beaker.org/ex/{beaker_id}"
+    beaker_url = _get_beaker_url(beaker_id)
     logging.info(f"Beaker ID: {beaker_id}")
     logging.info(f"Beaker URL: {beaker_url}")
     return beaker_url
diff --git a/fme/coupled/inference/evaluator.py b/fme/coupled/inference/evaluator.py
index 591e202a5..48d5a0f49 100644
--- a/fme/coupled/inference/evaluator.py
+++ b/fme/coupled/inference/evaluator.py
@@ -7,7 +7,6 @@
 import torch
 
 import fme
-import fme.core.logging_utils as logging_utils
 from fme.ace.stepper import load_stepper as load_single_stepper
 from fme.ace.stepper import load_stepper_config as load_single_stepper_config
 from fme.core.cli import prepare_config, prepare_directory
@@ -189,14 +188,9 @@ class InferenceEvaluatorConfig:
     )
 
     def configure_logging(self, log_filename: str):
-        self.logging.configure_logging(self.experiment_dir, log_filename)
-
-    def configure_wandb(
-        self, env_vars: dict | None = None, resumable: bool = False, **kwargs
-    ):
         config = to_flat_dict(dataclasses.asdict(self))
-        self.logging.configure_wandb(
-            config=config, env_vars=env_vars, resumable=resumable, **kwargs
+        self.logging.configure_logging(
+            self.experiment_dir, log_filename, config=config, resumable=False
         )
 
     def load_stepper(self) -> CoupledStepper:
@@ -268,16 +262,10 @@ def run_evaluator_from_config(config: InferenceEvaluatorConfig):
 
     makedirs(config.experiment_dir, exist_ok=True)
     config.configure_logging(log_filename="inference_out.log")
-    env_vars = logging_utils.retrieve_env_vars()
-    beaker_url = logging_utils.log_beaker_url()
-    config.configure_wandb(env_vars=env_vars, notes=beaker_url)
 
     if fme.using_gpu():
         torch.backends.cudnn.benchmark = True
 
-    logging_utils.log_versions()
-    logging.info(f"Current device is {fme.get_device()}")
-
     stepper_config = config.load_stepper_config()
     logging.info("Loading inference data")
     window_requirements = stepper_config.get_evaluation_window_data_requirements(
diff --git a/fme/coupled/inference/inference.py b/fme/coupled/inference/inference.py
index e29ca520c..c94a7961f 100644
--- a/fme/coupled/inference/inference.py
+++ b/fme/coupled/inference/inference.py
@@ -8,7 +8,6 @@
 import xarray as xr
 
 import fme
-import fme.core.logging_utils as logging_utils
 from fme.ace.data_loading.inference import (
     ExplicitIndices,
     InferenceInitialConditionIndices,
@@ -143,14 +142,9 @@ class InferenceConfig:
     n_ensemble_per_ic: int = 1
 
     def configure_logging(self, log_filename: str):
-        self.logging.configure_logging(self.experiment_dir, log_filename)
-
-    def configure_wandb(
-        self, env_vars: dict | None = None, resumable: bool = False, **kwargs
-    ):
         config = to_flat_dict(dataclasses.asdict(self))
-        self.logging.configure_wandb(
-            config=config, env_vars=env_vars, resumable=resumable, **kwargs
+        self.logging.configure_logging(
+            self.experiment_dir, log_filename, config=config, resumable=False
         )
 
     def load_stepper(self) -> CoupledStepper:
@@ -227,16 +221,10 @@ def run_inference_from_config(config: InferenceConfig):
 
     makedirs(config.experiment_dir, exist_ok=True)
     config.configure_logging(log_filename="inference_out.log")
-    env_vars = logging_utils.retrieve_env_vars()
-    beaker_url = logging_utils.log_beaker_url()
-    config.configure_wandb(env_vars=env_vars, notes=beaker_url)
 
     if fme.using_gpu():
         torch.backends.cudnn.benchmark = True
 
-    logging_utils.log_versions()
-    logging.info(f"Current device is {fme.get_device()}")
-
     stepper_config = config.load_stepper_config()
     data_requirements = stepper_config.get_forcing_window_data_requirements(
         n_coupled_steps=config.coupled_steps_in_memory
diff --git a/fme/coupled/train/train.py b/fme/coupled/train/train.py
index b10d3f9a8..6f9e439c2 100644
--- a/fme/coupled/train/train.py
+++ b/fme/coupled/train/train.py
@@ -8,10 +8,8 @@
 import xarray as xr
 
 import fme
-import fme.core.logging_utils as logging_utils
 from fme.core.cli import prepare_config, prepare_directory
 from fme.core.derived_variables import get_derived_variable_metadata
-from fme.core.dicts import to_flat_dict
 from fme.core.distributed import Distributed
 from fme.core.generics.trainer import AggregatorBuilderABC, Trainer
 from fme.core.typing_ import TensorDict, TensorMapping
@@ -144,13 +142,12 @@ def run_train(builders: TrainBuilders, config: TrainConfig):
         torch.backends.cudnn.benchmark = True
     if not os.path.isdir(config.experiment_dir):
         os.makedirs(config.experiment_dir, exist_ok=True)
-    config.logging.configure_logging(config.experiment_dir, log_filename="out.log")
-    env_vars = logging_utils.retrieve_env_vars()
-    logging_utils.log_versions()
-    beaker_url = logging_utils.log_beaker_url()
-    config_as_dict = to_flat_dict(dataclasses.asdict(config))
-    config.logging.configure_wandb(
-        config=config_as_dict, env_vars=env_vars, notes=beaker_url
+    config_data = dataclasses.asdict(config)
+    config.logging.configure_logging(
+        config.experiment_dir,
+        log_filename="out.log",
+        config=config_data,
+        resumable=True,
     )
     if config.resume_results is not None:
         logging.info(
diff --git a/fme/diffusion/train.py b/fme/diffusion/train.py
index a24b4b93b..13810cd70 100644
--- a/fme/diffusion/train.py
+++ b/fme/diffusion/train.py
@@ -58,7 +58,6 @@
 import xarray as xr
 
 import fme
-import fme.core.logging_utils as logging_utils
 from fme.ace.aggregator import OneStepAggregator, TrainAggregator
 from fme.ace.aggregator.inference.main import (
     InferenceEvaluatorAggregator,
@@ -69,7 +68,6 @@
 from fme.core.cli import get_parser, prepare_config, prepare_directory
 from fme.core.dataset.data_typing import VariableMetadata
 from fme.core.dataset_info import DatasetInfo
-from fme.core.dicts import to_flat_dict
 from fme.core.distributed import Distributed
 from fme.core.generics.trainer import AggregatorBuilderABC, TrainConfigProtocol, Trainer
 from fme.core.typing_ import TensorDict, TensorMapping
@@ -204,13 +202,12 @@ def run_train(builders: TrainBuilders, config: TrainConfig):
         torch.backends.cudnn.benchmark = True
     if not os.path.isdir(config.experiment_dir):
         os.makedirs(config.experiment_dir, exist_ok=True)
-    config.logging.configure_logging(config.experiment_dir, log_filename="out.log")
-    env_vars = logging_utils.retrieve_env_vars()
-    logging_utils.log_versions()
-    beaker_url = logging_utils.log_beaker_url()
-    config_as_dict = to_flat_dict(dataclasses.asdict(config))
-    config.logging.configure_wandb(
-        config=config_as_dict, env_vars=env_vars, resumable=True, notes=beaker_url
+    config_data = dataclasses.asdict(config)
+    config.logging.configure_logging(
+        config.experiment_dir,
+        log_filename="out.log",
+        config=config_data,
+        resumable=True,
     )
     if config.resume_results is not None:
         logging.info(
diff --git a/fme/downscaling/evaluator.py b/fme/downscaling/evaluator.py
index 3e970d805..d2ab88c28 100644
--- a/fme/downscaling/evaluator.py
+++ b/fme/downscaling/evaluator.py
@@ -6,7 +6,6 @@
 import torch
 import yaml
 
-import fme.core.logging_utils as logging_utils
 from fme.core.cli import prepare_directory
 from fme.core.dicts import to_flat_dict
 from fme.core.distributed import Distributed
@@ -196,13 +195,9 @@ class EvaluatorConfig:
     events: list[PairedEventConfig] | None = None
 
     def configure_logging(self, log_filename: str):
-        self.logging.configure_logging(self.experiment_dir, log_filename)
-
-    def configure_wandb(self, resumable: bool = False, **kwargs):
         config = to_flat_dict(dataclasses.asdict(self))
-        env_vars = logging_utils.retrieve_env_vars()
-        self.logging.configure_wandb(
-            config=config, env_vars=env_vars, resumable=resumable, **kwargs
+        self.logging.configure_logging(
+            self.experiment_dir, log_filename, config=config, resumable=True
         )
 
     def _build_default_evaluator(self) -> Evaluator:
@@ -289,9 +284,6 @@ def main(config_path: str):
 
     prepare_directory(evaluator_config.experiment_dir, config)
     evaluator_config.configure_logging(log_filename="out.log")
-    logging_utils.log_versions()
-    beaker_url = logging_utils.log_beaker_url()
-    evaluator_config.configure_wandb(resumable=True, notes=beaker_url)
 
     logging.info("Starting downscaling model evaluation")
     evaluators = evaluator_config.build()
diff --git a/fme/downscaling/inference/inference.py b/fme/downscaling/inference/inference.py
index cf1ce055c..cb1319d95 100644
--- a/fme/downscaling/inference/inference.py
+++ b/fme/downscaling/inference/inference.py
@@ -1,11 +1,11 @@
+import dataclasses
 import logging
-from dataclasses import asdict, dataclass, field
+from dataclasses import dataclass, field
 
 import dacite
 import torch
 import yaml
 
-from fme.core import logging_utils
 from fme.core.cli import prepare_directory
 from fme.core.dicts import to_flat_dict
 from fme.core.logging_utils import LoggingConfig
@@ -226,13 +226,9 @@ class InferenceConfig:
     patch: PatchPredictionConfig = field(default_factory=PatchPredictionConfig)
 
     def configure_logging(self, log_filename: str):
-        self.logging.configure_logging(self.experiment_dir, log_filename)
-
-    def configure_wandb(self, resumable: bool = False, **kwargs):
-        config = to_flat_dict(asdict(self))
-        env_vars = logging_utils.retrieve_env_vars()
-        self.logging.configure_wandb(
-            config=config, env_vars=env_vars, resumable=resumable, **kwargs
+        config = to_flat_dict(dataclasses.asdict(self))
+        self.logging.configure_logging(
+            self.experiment_dir, log_filename, config=config, resumable=True
         )
 
     def build(self) -> Downscaler:
@@ -261,9 +257,6 @@ def main(config_path: str):
 
     prepare_directory(generation_config.experiment_dir, config)
    generation_config.configure_logging(log_filename="out.log")
-    logging_utils.log_versions()
-    beaker_url = logging_utils.log_beaker_url()
-    generation_config.configure_wandb(resumable=True, notes=beaker_url)
 
     logging.info("Starting downscaling generation...")
     downscaler = generation_config.build()
diff --git a/fme/downscaling/predict.py b/fme/downscaling/predict.py
index 9daaa778c..9dc5d6891 100644
--- a/fme/downscaling/predict.py
+++ b/fme/downscaling/predict.py
@@ -8,7 +8,6 @@
 import xarray as xr
 import yaml
 
-import fme.core.logging_utils as logging_utils
 from fme.core.cli import prepare_directory
 from fme.core.coordinates import LatLonCoordinates
 from fme.core.dataset.time import TimeSlice
@@ -295,13 +294,9 @@ class DownscalerConfig:
     """
 
     def configure_logging(self, log_filename: str):
-        self.logging.configure_logging(self.experiment_dir, log_filename)
-
-    def configure_wandb(self, resumable: bool = False, **kwargs):
         config = to_flat_dict(dataclasses.asdict(self))
-        env_vars = logging_utils.retrieve_env_vars()
-        self.logging.configure_wandb(
-            config=config, env_vars=env_vars, resumable=resumable, **kwargs
+        self.logging.configure_logging(
+            self.experiment_dir, log_filename, config=config, resumable=True
         )
 
     def build(self) -> list[Downscaler | EventDownscaler]:
@@ -350,9 +345,6 @@ def main(config_path: str):
 
     prepare_directory(downscaler_config.experiment_dir, config)
     downscaler_config.configure_logging(log_filename="out.log")
-    logging_utils.log_versions()
-    beaker_url = logging_utils.log_beaker_url()
-    downscaler_config.configure_wandb(resumable=True, notes=beaker_url)
 
     logging.info("Starting downscaling model generation...")
     downscalers = downscaler_config.build()
diff --git a/fme/downscaling/train.py b/fme/downscaling/train.py
index d155adf3a..b0a1c1560 100755
--- a/fme/downscaling/train.py
+++ b/fme/downscaling/train.py
@@ -11,7 +11,6 @@
 import torch
 import yaml
 
-import fme.core.logging_utils as logging_utils
 from fme.core.cli import prepare_directory
 from fme.core.dataset.xarray import get_raw_paths
 from fme.core.device import get_device
@@ -476,13 +475,9 @@ def build(self) -> Trainer:
         )
 
     def configure_logging(self, log_filename: str):
-        self.logging.configure_logging(self.experiment_dir, log_filename)
-
-    def configure_wandb(self, resumable: bool = True, **kwargs):
         config = to_flat_dict(dataclasses.asdict(self))
-        env_vars = logging_utils.retrieve_env_vars()
-        self.logging.configure_wandb(
-            config=config, env_vars=env_vars, resumable=resumable, **kwargs
+        self.logging.configure_logging(
+            self.experiment_dir, log_filename, config=config, resumable=True
         )
 
 
@@ -529,9 +524,6 @@ def main(config_path: str):
 
     prepare_directory(train_config.experiment_dir, config)
     train_config.configure_logging(log_filename="out.log")
-    logging_utils.log_versions()
-    beaker_url = logging_utils.log_beaker_url()
-    train_config.configure_wandb(notes=beaker_url)
 
     logging.info("Starting training")
     trainer = train_config.build()
diff --git a/scripts/monthly_data/write_monthly_data.py b/scripts/monthly_data/write_monthly_data.py
index e03131136..7c889a121 100644
--- a/scripts/monthly_data/write_monthly_data.py
+++ b/scripts/monthly_data/write_monthly_data.py
@@ -9,7 +9,6 @@
 import xarray as xr
 import yaml
 
-import fme.core.logging_utils as logging_utils
 from fme.ace.data_loading.batch_data import BatchData, default_collate
 from fme.ace.data_loading.config import DataLoaderConfig
 from fme.ace.inference.data_writer.dataset_metadata import DatasetMetadata
@@ -25,6 +24,7 @@
 from fme.core.dataset.properties import DatasetProperties
 from fme.core.dataset.xarray import get_xarray_datasets
 from fme.core.device import using_gpu
+from fme.core.dicts import to_flat_dict
 from fme.core.distributed import Distributed
 from fme.core.logging_utils import LoggingConfig
 
@@ -136,7 +136,10 @@ def get_data(self) -> "Data":
         )
 
     def configure_logging(self, log_filename: str):
-        self.logging.configure_logging(self.experiment_dir, log_filename)
+        config = to_flat_dict(dataclasses.asdict(self))
+        self.logging.configure_logging(
+            self.experiment_dir, log_filename, config=config, resumable=False
+        )
 
     def get_data_writer(self, data: "Data") -> MonthlyDataWriter:
         assert data.properties.timestep is not None
@@ -179,7 +182,6 @@ def merge_loaders(loaders: List[torch.utils.data.DataLoader]):
 
 def run(config: Config):
     config.configure_logging(log_filename="write_monthly_data_out.log")
-    logging_utils.log_versions()
 
     data = config.get_data()
     writer = config.get_data_writer(data)

From e68a1ed8824963a7c53cd5cc50bea984b6a5a11d Mon Sep 17 00:00:00 2001
From: Jeremy McGibbon
Date: Mon, 2 Feb 2026 16:14:30 +0000
Subject: [PATCH 2/7] only log versions once

---
 fme/core/logging_utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/fme/core/logging_utils.py b/fme/core/logging_utils.py
index a282f731a..1f904a08c 100644
--- a/fme/core/logging_utils.py
+++ b/fme/core/logging_utils.py
@@ -74,7 +74,6 @@ def configure_logging(
             config=config,
             resumable=resumable,
         )
-        log_versions()
         logging.info(f"Current device is {get_device()}")
 
     def _configure_logging_module(self, experiment_dir: str, log_filename: str):

From e98f2b1cadfaa06b300e39a8e7fd3439849beb5a Mon Sep 17 00:00:00 2001
From: Jeremy McGibbon
Date: Mon, 2 Feb 2026 16:15:40 +0000
Subject: [PATCH 3/7] add docstring

---
 fme/core/logging_utils.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/fme/core/logging_utils.py b/fme/core/logging_utils.py
index 1f904a08c..71012c293 100644
--- a/fme/core/logging_utils.py
+++ b/fme/core/logging_utils.py
@@ -66,6 +66,12 @@ def configure_logging(
         """
         Configure global logging settings, including WandB, and output
         initial logs of the runtime environment.
+
+        Args:
+            experiment_dir: Directory to save logs to.
+            log_filename: Name of the log file.
+            config: Configuration dictionary to log to WandB.
+            resumable: Whether this is a resumable run.
         """
         self._configure_logging_module(experiment_dir, log_filename)
         log_versions()
         log_beaker_url()

From fa93fe7fc5a465d61460814c9e1563a86feeb71c Mon Sep 17 00:00:00 2001
From: Jeremy McGibbon
Date: Mon, 2 Feb 2026 17:00:16 +0000
Subject: [PATCH 4/7] import correct get_device

---
 fme/core/logging_utils.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/fme/core/logging_utils.py b/fme/core/logging_utils.py
index 71012c293..6f53235f7 100644
--- a/fme/core/logging_utils.py
+++ b/fme/core/logging_utils.py
@@ -6,9 +6,8 @@
 from collections.abc import Mapping
 from typing import Any
 
-from torch import get_device
-
 from fme.core.cloud import is_local
+from fme.core.device import get_device
 from fme.core.distributed import Distributed
 from fme.core.wandb import WandB

From 41cbdce942b964c90a46baab82a7841e71b3a704 Mon Sep 17 00:00:00 2001
From: Jeremy McGibbon
Date: Mon, 2 Feb 2026 17:04:57 +0000
Subject: [PATCH 5/7] avoid depending on config api in wandb setup method

---
 fme/ace/data_loading/benchmark.py |  8 --------
 fme/core/generics/test_trainer.py | 12 +++++++++---
 fme/core/logging_utils.py         |  7 ++++---
 3 files changed, 13 insertions(+), 14 deletions(-)

diff --git a/fme/ace/data_loading/benchmark.py b/fme/ace/data_loading/benchmark.py
index 952f33ec7..76c65f796 100644
--- a/fme/ace/data_loading/benchmark.py
+++ b/fme/ace/data_loading/benchmark.py
@@ -5,7 +5,6 @@
 import argparse
 import dataclasses
 import logging
-import os
 import shutil
 import time
 import uuid
@@ -43,13 +42,6 @@ def build(self):
             requirements=DataRequirements(self.names, self.n_timesteps),
         )
 
-    def configure_wandb(self, env_vars: dict | None = None, **kwargs):
-        config = to_flat_dict(dataclasses.asdict(self))
-        # our wandb class requires "experiment_dir" to be in config
-        config["experiment_dir"] = TMPDIR
-        os.makedirs(TMPDIR)
-        self.logging.configure_wandb(config=config, env_vars=env_vars, **kwargs)
-
     def configure_logging(self):
         config = to_flat_dict(dataclasses.asdict(self))
         self.logging.configure_logging(
diff --git a/fme/core/generics/test_trainer.py b/fme/core/generics/test_trainer.py
index 7a4cbde2a..4fbc31960 100644
--- a/fme/core/generics/test_trainer.py
+++ b/fme/core/generics/test_trainer.py
@@ -922,7 +922,9 @@ def _get_trainer(train_losses, val_losses, inference_errors):
         return trainer
 
     with mock_wandb() as wandb:
-        LoggingConfig(log_to_wandb=True).configure_wandb({"experiment_dir": tmp_path})
+        LoggingConfig(log_to_wandb=True)._configure_wandb(
+            experiment_dir=tmp_path, config={}, resumable=True
+        )
         # run training in two segments to ensure coverage of check that extra validation
         # really only happens before any training is done.
         trainer = _get_trainer(train_losses[:1], val_losses[:2], inference_errors[:2])
@@ -1066,7 +1068,9 @@ def _get_trainer(train_losses, val_losses, inference_errors):
         return trainer
 
     with mock_wandb() as wandb:
-        LoggingConfig(log_to_wandb=True).configure_wandb({"experiment_dir": tmp_path})
+        LoggingConfig(log_to_wandb=True)._configure_wandb(
+            experiment_dir=tmp_path, config={}, resumable=True
+        )
         trainer = _get_trainer(train_losses, val_losses, inference_errors)
         trainer.train()
         wandb_logs = wandb.get_logs()
@@ -1105,7 +1109,9 @@ def _get_trainer(train_losses, val_losses, inference_errors):
         return trainer
 
     with mock_wandb() as wandb:
-        LoggingConfig(log_to_wandb=True).configure_wandb({"experiment_dir": tmp_path})
+        LoggingConfig(log_to_wandb=True)._configure_wandb(
+            experiment_dir=tmp_path, config={}, resumable=True
+        )
         trainer = _get_trainer(train_losses, val_losses, inference_errors)
         trainer.train()
         wandb_logs = wandb.get_logs()
diff --git a/fme/core/logging_utils.py b/fme/core/logging_utils.py
index 6f53235f7..74e97e969 100644
--- a/fme/core/logging_utils.py
+++ b/fme/core/logging_utils.py
@@ -75,7 +75,8 @@ def configure_logging(
         self._configure_logging_module(experiment_dir, log_filename)
         log_versions()
         log_beaker_url()
-        self.configure_wandb(
+        self._configure_wandb(
+            experiment_dir=experiment_dir,
             config=config,
             resumable=resumable,
         )
@@ -112,8 +113,9 @@ def _configure_logging_module(self, experiment_dir: str, log_filename: str):
             fh.setFormatter(logging.Formatter(self.log_format))
             logger.addHandler(fh)
 
-    def configure_wandb(
+    def _configure_wandb(
         self,
+        experiment_dir: str,
         config: Mapping[str, Any],
         resumable: bool = True,
         resume: Any = None,
@@ -134,7 +136,6 @@ def configure_wandb(
         elif env_vars is not None:
             config_copy["environment"] = env_vars
 
-        experiment_dir = config["experiment_dir"]
         if self.wandb_dir_in_experiment_dir:
             wandb_dir = experiment_dir
         else:

From 4bcc46f9bf40127b92358cb2966d231590e0be81 Mon Sep 17 00:00:00 2001
From: Jeremy McGibbon
Date: Mon, 2 Feb 2026 17:09:33 +0000
Subject: [PATCH 6/7] consolidate temp dirs

---
 fme/ace/data_loading/benchmark.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fme/ace/data_loading/benchmark.py b/fme/ace/data_loading/benchmark.py
index 76c65f796..b85d41e55 100644
--- a/fme/ace/data_loading/benchmark.py
+++ b/fme/ace/data_loading/benchmark.py
@@ -45,7 +45,7 @@ def build(self):
     def configure_logging(self):
         config = to_flat_dict(dataclasses.asdict(self))
         self.logging.configure_logging(
-            "/tmp", "log.txt", config=config, resumable=False
+            TMPDIR, "log.txt", config=config, resumable=False
         )

From 866f7e84b99d38bc20e77c976c799f6f4afe2346 Mon Sep 17 00:00:00 2001
From: Jeremy McGibbon
Date: Mon, 2 Feb 2026 17:11:53 +0000
Subject: [PATCH 7/7] make the tmpdir in benchmark

---
 fme/ace/data_loading/benchmark.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/fme/ace/data_loading/benchmark.py b/fme/ace/data_loading/benchmark.py
index b85d41e55..50d77f588 100644
--- a/fme/ace/data_loading/benchmark.py
+++ b/fme/ace/data_loading/benchmark.py
@@ -5,6 +5,7 @@
 import argparse
 import dataclasses
 import logging
+import os
 import shutil
 import time
 import uuid
@@ -44,6 +45,7 @@ def build(self):
 
     def configure_logging(self):
         config = to_flat_dict(dataclasses.asdict(self))
+        os.makedirs(TMPDIR, exist_ok=True)
         self.logging.configure_logging(
             TMPDIR, "log.txt", config=config, resumable=False
         )