diff --git a/icu_benchmarks/models/train.py b/icu_benchmarks/models/train.py index 1e5360b1..d7fe83cc 100644 --- a/icu_benchmarks/models/train.py +++ b/icu_benchmarks/models/train.py @@ -140,8 +140,13 @@ def train_common( model.set_weight(weight, train_dataset) model.set_trained_columns(train_dataset.get_feature_names()) loggers = [TensorBoardLogger(log_dir), JSONMetricsLogger(log_dir)] + devices = max(torch.cuda.device_count(), 1) + if use_wandb: loggers.append(WandbLogger(save_dir=log_dir)) + logging.info("Use of wandb is detected. Only single gpu training is supported with wandb.") + devices = 1 + callbacks = [ EarlyStopping(monitor="val/loss", min_delta=min_delta, patience=patience, strict=False, verbose=verbose), ModelCheckpoint(log_dir, filename="model", save_top_k=1, save_last=True), @@ -158,7 +163,7 @@ def train_common( callbacks=callbacks, precision=precision, accelerator="auto" if not cpu else "cpu", - devices=max(torch.cuda.device_count(), 1), + devices=devices, deterministic="warn" if reproducible else False, benchmark=not reproducible, enable_progress_bar=verbose,