Merged
src/gfn/env.py (2 changes: 1 addition & 1 deletion)
@@ -271,7 +271,7 @@ def reset(
         assert not (random and sink)

         if random and seed is not None:
-            set_seed(seed, performance_mode=True)
+            set_seed(seed, deterministic_mode=False)  # TODO: configurable?

         if isinstance(batch_shape, int):
             batch_shape = (batch_shape,)
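A note on the `# TODO: configurable?` above: one way to resolve it would be to surface the flag on `reset` itself. A minimal sketch of that idea (the `deterministic_mode` parameter on `reset` is hypothetical, not part of this PR, and the signature is abbreviated):

```python
# Hypothetical sketch only: expose the seeding mode to callers of reset().
def reset(self, batch_shape, random=False, sink=False, seed=None,
          deterministic_mode=False):  # new, caller-controlled knob (assumption)
    assert not (random and sink)

    if random and seed is not None:
        # Forward the caller's choice instead of hard-coding it here.
        set_seed(seed, deterministic_mode=deterministic_mode)
    ...
```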
src/gfn/utils/common.py (35 changes: 20 additions & 15 deletions)
@@ -145,12 +145,12 @@ def filter_kwargs_for_callable(
 # -----------------------------------------------------------------------------


-def set_seed(seed: int, performance_mode: bool = False) -> None:
+def set_seed(seed: int, deterministic_mode: bool = False) -> None:
     """Used to control randomness for both single and distributed training.

     Args:
         seed: The seed to use for all random number generators
-        performance_mode: If True, disables deterministic behavior for better performance.
-            In multi-GPU settings, this only affects cuDNN. In multi-CPU settings,
-            this allows parallel processing in NumPy.
+        deterministic_mode: If True, enables deterministic behavior for better
+            reproducibility at the cost of performance. In multi-GPU settings, this
+            only affects cuDNN; in multi-CPU settings, it disables parallel
+            processing in NumPy.
     """
@@ -186,19 +186,22 @@ def set_seed(seed: int, performance_mode: bool = False) -> None:

         # Set device-specific environment variables
         if torch.cuda.is_available():
-            # For GPU training, we can use multiple threads for CPU operations
-            if performance_mode:
-                os.environ["OMP_NUM_THREADS"] = str(num_cpus)
-                os.environ["MKL_NUM_THREADS"] = str(num_cpus)
-            else:
+            if deterministic_mode:
                 # For reproducibility in GPU training, we still want deterministic
                 # CPU operations
                 os.environ["OMP_NUM_THREADS"] = "1"
                 os.environ["MKL_NUM_THREADS"] = "1"
+            else:
+                # For GPU training, we can use multiple threads for CPU operations
+                os.environ["OMP_NUM_THREADS"] = str(num_cpus)
+                os.environ["MKL_NUM_THREADS"] = str(num_cpus)
Comment on lines +194 to +197

Collaborator:
isn't this the default behavior? If yes, I think we can remove this else branch

Collaborator (Author):
I'm not entirely sure what is default or not, so I'd like to keep this in for now. It was really hard to nail down the cause of the irreproducibility errors we were seeing before in distributed training.
         else:
             # For CPU-only training, we need to be more careful with threading
-            if performance_mode:
+
+            if deterministic_mode:
+                # For perfect reproducibility in CPU training, disable parallel processing
+                os.environ["OMP_NUM_THREADS"] = "1"
+                os.environ["MKL_NUM_THREADS"] = "1"
+            else:
                 # Allow parallel processing but with controlled number of threads
                 # Different backends might handle threading differently
                 if backend in ["mpi", "ccl"]:
@@ -211,10 +214,6 @@ def set_seed(seed: int, performance_mode: bool = False) -> None:

                 os.environ["OMP_NUM_THREADS"] = str(num_threads)
                 os.environ["MKL_NUM_THREADS"] = str(num_threads)
-            else:
-                # For perfect reproducibility in CPU training, disable parallel processing
-                os.environ["OMP_NUM_THREADS"] = "1"
-                os.environ["MKL_NUM_THREADS"] = "1"

     else:
         # Non-distributed training - use the global seed
@@ -237,7 +236,7 @@ def set_seed(seed: int, performance_mode: bool = False) -> None:
             threading.current_thread()._seed = seed

     # These are only set when we care about reproducibility over performance
-    if not performance_mode:
+    if deterministic_mode:
         # GPU-specific settings
         if torch.cuda.is_available():
             torch.backends.cudnn.deterministic = True
@@ -314,9 +313,15 @@ def temporarily_set_seed(seed):

 def make_dataloader_seed_fns(
     base_seed: int,
+    deterministic_mode: bool = False,
 ) -> Tuple[Callable[[int], None], torch.Generator]:
     """Return `(worker_init_fn, generator)` for DataLoader reproducibility.

+    Args:
+        base_seed: The base seed to use for the DataLoader.
+        deterministic_mode: If True, uses deterministic behavior for better
+            reproducibility at the cost of performance.
+
     Example
     -------
     >>> w_init, g = make_dataloader_seed_fns(process_seed)
@@ -332,7 +337,7 @@ def make_dataloader_seed_fns(

     def _worker_init_fn(worker_id: int) -> None:  # pragma: no cover
         # Each worker gets a distinct seed in the same pattern used for ranks.
-        set_seed(base_seed + worker_id, performance_mode=False)
+        set_seed(base_seed + worker_id, deterministic_mode=deterministic_mode)

     gen = torch.Generator()
     gen.manual_seed(base_seed)
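Putting the new parameter to work, a DataLoader would be wired up roughly as follows (a sketch; the dataset, batch size, and worker count are placeholders):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from gfn.utils.common import make_dataloader_seed_fns

dataset = TensorDataset(torch.arange(100, dtype=torch.float32))  # placeholder

w_init, gen = make_dataloader_seed_fns(base_seed=1234, deterministic_mode=True)
loader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    num_workers=2,          # each worker is re-seeded as base_seed + worker_id
    worker_init_fn=w_init,  # runs set_seed(...) inside every worker process
    generator=gen,          # makes the shuffle order reproducible
)
```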
testing/test_estimators.py (2 changes: 1 addition & 1 deletion)
@@ -476,7 +476,7 @@ def test_uniform_log_probs_method():
 def test_mix_with_uniform_in_log_space():
     """Test the _mix_with_uniform_in_log_space static method."""
     batch_size, n_actions = 3, 4
-    set_seed(123)
+    set_seed(123, deterministic_mode=True)

     # Create log-softmax values
     logits = torch.randn(batch_size, n_actions)
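Finally, an end-to-end sanity check of the renamed flag (a sketch written against the new API):

```python
import torch

from gfn.utils.common import set_seed


def sample_after_seed(seed: int) -> torch.Tensor:
    set_seed(seed, deterministic_mode=True)
    return torch.randn(3, 4)


# Re-seeding with the same value must reproduce the exact same draw.
assert torch.equal(sample_after_seed(123), sample_after_seed(123))
```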