diff --git a/pyproject.toml b/pyproject.toml index dec4906e..e6eb7541 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "arc-state" -version = "0.9.31" +version = "0.9.32" description = "State is a machine learning model that predicts cellular perturbation response across diverse contexts." readme = "README.md" authors = [ diff --git a/src/state/configs/wandb/default.yaml b/src/state/configs/wandb/default.yaml index 0d13940a..a773a9a9 100644 --- a/src/state/configs/wandb/default.yaml +++ b/src/state/configs/wandb/default.yaml @@ -3,4 +3,5 @@ entity: your_entity_name project: state local_wandb_dir: ./wandb_logs -tags: [] +tags: [] +group: null diff --git a/src/state/tx/utils/__init__.py b/src/state/tx/utils/__init__.py index 7a35c853..5d426414 100644 --- a/src/state/tx/utils/__init__.py +++ b/src/state/tx/utils/__init__.py @@ -1,10 +1,10 @@ import time import logging +import os from contextlib import contextmanager from lightning.pytorch.loggers import CSVLogger, WandbLogger from lightning.pytorch.loggers.csv_logs import CSVLogger as BaseCSVLogger import csv -import os from lightning.pytorch.callbacks import ModelCheckpoint from os.path import join @@ -100,13 +100,14 @@ def get_loggers( try: # Check if wandb is available import wandb - + wandb_logger = WandbLogger( name=name, project=wandb_project, entity=wandb_entity, dir=local_wandb_dir, tags=cfg["wandb"].get("tags", []) if cfg else [], + group=cfg["wandb"].get("group", None) if cfg else None, ) if cfg is not None: wandb_logger.experiment.config.update(cfg)