9 changes: 7 additions & 2 deletions examples/conftest.py
@@ -25,20 +25,25 @@
git_repo_path = abspath(join(dirname(dirname(dirname(__file__))), "src"))
sys.path.insert(1, git_repo_path)

# Add parent directory to path so we can import from tests
repo_root = abspath(dirname(dirname(__file__)))
if repo_root not in sys.path:
sys.path.insert(0, repo_root)


# silence FutureWarning warnings in tests since often we can't act on them until
# they become normal warnings - i.e. the tests still need to test the current functionality
warnings.simplefilter(action="ignore", category=FutureWarning)


def pytest_addoption(parser):
from diffusers.utils.testing_utils import pytest_addoption_shared
from tests.testing_utils import pytest_addoption_shared

pytest_addoption_shared(parser)


def pytest_terminal_summary(terminalreporter):
from diffusers.utils.testing_utils import pytest_terminal_summary_main
from tests.testing_utils import pytest_terminal_summary_main

make_reports = terminalreporter.config.getoption("--make-reports")
if make_reports:
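For orientation, here is a condensed sketch of the wiring this conftest change produces (a sketch only, assuming the shared pytest helpers now live in the repo-level `tests/testing_utils.py`, as the new imports indicate):

```python
# examples/conftest.py (condensed sketch, not the full file)
import sys
from os.path import abspath, dirname

# Put the repo root on sys.path so `tests.testing_utils` resolves
# when pytest collects the example tests.
repo_root = abspath(dirname(dirname(__file__)))
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)

from tests.testing_utils import pytest_addoption_shared  # noqa: E402


def pytest_addoption(parser):
    # Delegate option registration to the shared helper in tests/.
    pytest_addoption_shared(parser)
```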
5 changes: 3 additions & 2 deletions examples/controlnet/train_controlnet_sd3.py
@@ -24,6 +24,8 @@
import os
import random
import shutil

# Add repo root to path to import from tests
from pathlib import Path

import accelerate
@@ -54,8 +56,7 @@
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.testing_utils import backend_empty_cache
from diffusers.utils.torch_utils import is_compiled_module
from diffusers.utils.torch_utils import backend_empty_cache, is_compiled_module


if is_wandb_available():
8 changes: 7 additions & 1 deletion examples/vqgan/test_vqgan.py
@@ -24,12 +24,18 @@
import torch

from diffusers import VQModel
from diffusers.utils.testing_utils import require_timm


# Add parent directories to path to import from tests
sys.path.append("..")
repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
if repo_root not in sys.path:
sys.path.insert(0, repo_root)

from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402

from tests.testing_utils import require_timm # noqa


logging.basicConfig(level=logging.DEBUG)

102 changes: 73 additions & 29 deletions src/diffusers/utils/testing_utils.py
@@ -66,7 +66,10 @@
global_rng = random.Random()

logger = get_logger(__name__)

logger.warning(
"diffusers.utils.testing_utils' is deprecated and will be removed in a future version. "
"Determinism and device backend utilities have been moved to `diffusers.utils.torch_utils`. "
)
_required_peft_version = is_peft_available() and version.parse(
version.parse(importlib.metadata.version("peft")).base_version
) > version.parse("0.5")
@@ -801,10 +804,9 @@ def export_to_ply(mesh, output_ply_path: str = None):
f.write(format.pack(*vertex))

if faces is not None:
format = struct.Struct("<B3I")
for tri in faces.tolist():
f.write(format.pack(len(tri), *tri))

format = struct.Struct("<B3I")
return output_ply_path


@@ -1144,23 +1146,23 @@ def enable_full_determinism():
Helper function for reproducible behavior during distributed training. See
- https://pytorch.org/docs/stable/notes/randomness.html for pytorch
"""
# Enable PyTorch deterministic mode. This potentially requires either the environment
# variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
# depending on the CUDA version, so we set them both here
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
torch.use_deterministic_algorithms(True)
from .torch_utils import enable_full_determinism as _enable_full_determinism

# Enable CUDNN deterministic mode
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = False
logger.warning(
"enable_full_determinism has been moved to diffusers.utils.torch_utils. "
"Importing from diffusers.utils.testing_utils is deprecated and will be removed in a future version."
)
return _enable_full_determinism()


def disable_full_determinism():
os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ""
torch.use_deterministic_algorithms(False)
from .torch_utils import disable_full_determinism as _disable_full_determinism

logger.warning(
"disable_full_determinism has been moved to diffusers.utils.torch_utils. "
"Importing from diffusers.utils.testing_utils is deprecated and will be removed in a future version."
)
return _disable_full_determinism()


# Utils for custom and alternative accelerator devices
@@ -1282,43 +1284,85 @@ def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable],

# These are callables which automatically dispatch the function specific to the accelerator
def backend_manual_seed(device: str, seed: int):
return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)
from .torch_utils import backend_manual_seed as _backend_manual_seed

logger.warning(
"backend_manual_seed has been moved to diffusers.utils.torch_utils. "
"diffusers.utils.testing_utils is deprecated and will be removed in a future version."
)
return _backend_manual_seed(device, seed)


def backend_synchronize(device: str):
return _device_agnostic_dispatch(device, BACKEND_SYNCHRONIZE)
from .torch_utils import backend_synchronize as _backend_synchronize

logger.warning(
"backend_synchronize has been moved to diffusers.utils.torch_utils. "
"diffusers.utils.testing_utils is deprecated and will be removed in a future version."
)
return _backend_synchronize(device)


def backend_empty_cache(device: str):
return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)
from .torch_utils import backend_empty_cache as _backend_empty_cache

logger.warning(
"backend_empty_cache has been moved to diffusers.utils.torch_utils. "
"diffusers.utils.testing_utils is deprecated and will be removed in a future version."
)
return _backend_empty_cache(device)


def backend_device_count(device: str):
return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
from .torch_utils import backend_device_count as _backend_device_count

logger.warning(
"backend_device_count has been moved to diffusers.utils.torch_utils. "
"diffusers.utils.testing_utils is deprecated and will be removed in a future version."
)
return _backend_device_count(device)


def backend_reset_peak_memory_stats(device: str):
return _device_agnostic_dispatch(device, BACKEND_RESET_PEAK_MEMORY_STATS)
from .torch_utils import backend_reset_peak_memory_stats as _backend_reset_peak_memory_stats

logger.warning(
"backend_reset_peak_memory_stats has been moved to diffusers.utils.torch_utils. "
"diffusers.utils.testing_utils is deprecated and will be removed in a future version."
)
return _backend_reset_peak_memory_stats(device)


def backend_reset_max_memory_allocated(device: str):
return _device_agnostic_dispatch(device, BACKEND_RESET_MAX_MEMORY_ALLOCATED)
from .torch_utils import backend_reset_max_memory_allocated as _backend_reset_max_memory_allocated

logger.warning(
"backend_reset_max_memory_allocated has been moved to diffusers.utils.torch_utils. "
"diffusers.utils.testing_utils is deprecated and will be removed in a future version."
)
return _backend_reset_max_memory_allocated(device)


def backend_max_memory_allocated(device: str):
return _device_agnostic_dispatch(device, BACKEND_MAX_MEMORY_ALLOCATED)
from .torch_utils import backend_max_memory_allocated as _backend_max_memory_allocated

logger.warning(
"backend_max_memory_allocated has been moved to diffusers.utils.torch_utils. "
"diffusers.utils.testing_utils is deprecated and will be removed in a future version."
)
return _backend_max_memory_allocated(device)


# These are callables which return boolean behaviour flags and can be used to specify some
# device agnostic alternative where the feature is unsupported.
def backend_supports_training(device: str):
if not is_torch_available():
return False

if device not in BACKEND_SUPPORTS_TRAINING:
device = "default"
from .torch_utils import backend_supports_training as _backend_supports_training

return BACKEND_SUPPORTS_TRAINING[device]
logger.warning(
"backend_supports_training has been moved to diffusers.utils.torch_utils. "
"diffusers.utils.testing_utils is deprecated and will be removed in a future version."
)
return _backend_supports_training(device)


# Guard for when Torch is not available
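Taken together, the `testing_utils.py` changes turn each moved helper into a thin shim: it imports the real implementation from `diffusers.utils.torch_utils`, logs a deprecation warning, and delegates. Downstream code migrates by switching the import path; a small before/after sketch, assuming the PR lands as shown:

```python
# Old location: still importable, but calls now go through the deprecation shim
# and log a warning.
from diffusers.utils.testing_utils import backend_empty_cache as old_backend_empty_cache

# New location introduced by this PR.
from diffusers.utils.torch_utils import backend_empty_cache, enable_full_determinism

enable_full_determinism()    # sets CUDA_LAUNCH_BLOCKING / CUBLAS_WORKSPACE_CONFIG, then
                             # torch.use_deterministic_algorithms(True) and cuDNN determinism
backend_empty_cache("cuda")  # dispatches to torch.cuda.empty_cache() on a CUDA machine
```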
137 changes: 136 additions & 1 deletion src/diffusers/utils/torch_utils.py
@@ -16,7 +16,8 @@
"""

import functools
from typing import List, Optional, Tuple, Union
import os
from typing import Callable, Dict, List, Optional, Tuple, Union

from . import logging
from .import_utils import is_torch_available, is_torch_npu_available, is_torch_version
@@ -26,6 +27,56 @@
import torch
from torch.fft import fftn, fftshift, ifftn, ifftshift

BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True}
BACKEND_EMPTY_CACHE = {
"cuda": torch.cuda.empty_cache,
"xpu": torch.xpu.empty_cache,
"cpu": None,
"mps": torch.mps.empty_cache,
"default": None,
}
BACKEND_DEVICE_COUNT = {
"cuda": torch.cuda.device_count,
"xpu": torch.xpu.device_count,
"cpu": lambda: 0,
"mps": lambda: 0,
"default": 0,
}
BACKEND_MANUAL_SEED = {
"cuda": torch.cuda.manual_seed,
"xpu": torch.xpu.manual_seed,
"cpu": torch.manual_seed,
"mps": torch.mps.manual_seed,
"default": torch.manual_seed,
}
BACKEND_RESET_PEAK_MEMORY_STATS = {
"cuda": torch.cuda.reset_peak_memory_stats,
"xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
"cpu": None,
"mps": None,
"default": None,
}
BACKEND_RESET_MAX_MEMORY_ALLOCATED = {
"cuda": torch.cuda.reset_max_memory_allocated,
"xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
"cpu": None,
"mps": None,
"default": None,
}
BACKEND_MAX_MEMORY_ALLOCATED = {
"cuda": torch.cuda.max_memory_allocated,
"xpu": getattr(torch.xpu, "max_memory_allocated", None),
"cpu": 0,
"mps": 0,
"default": 0,
}
BACKEND_SYNCHRONIZE = {
"cuda": torch.cuda.synchronize,
"xpu": getattr(torch.xpu, "synchronize", None),
"cpu": None,
"mps": None,
"default": None,
}
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

try:
@@ -36,6 +87,62 @@ def maybe_allow_in_graph(cls):
return cls


# This dispatches a defined function according to the accelerator from the function definitions.
def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs):
if device not in dispatch_table:
return dispatch_table["default"](*args, **kwargs)

fn = dispatch_table[device]

# Some device agnostic functions return values. Need to guard against 'None' instead at
# user level
if not callable(fn):
return fn

return fn(*args, **kwargs)


# These are callables which automatically dispatch the function specific to the accelerator
def backend_manual_seed(device: str, seed: int):
return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)


def backend_synchronize(device: str):
return _device_agnostic_dispatch(device, BACKEND_SYNCHRONIZE)


def backend_empty_cache(device: str):
return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)


def backend_device_count(device: str):
return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)


def backend_reset_peak_memory_stats(device: str):
return _device_agnostic_dispatch(device, BACKEND_RESET_PEAK_MEMORY_STATS)


def backend_reset_max_memory_allocated(device: str):
return _device_agnostic_dispatch(device, BACKEND_RESET_MAX_MEMORY_ALLOCATED)


def backend_max_memory_allocated(device: str):
return _device_agnostic_dispatch(device, BACKEND_MAX_MEMORY_ALLOCATED)


# These are callables which return boolean behaviour flags and can be used to specify some
# device agnostic alternative where the feature is unsupported.
def backend_supports_training(device: str):
if not is_torch_available():
return False

if device not in BACKEND_SUPPORTS_TRAINING:
device = "default"

return BACKEND_SUPPORTS_TRAINING[device]


def randn_tensor(
shape: Union[Tuple, List],
generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
@@ -197,3 +304,31 @@ def device_synchronize(device_type: Optional[str] = None):
device_type = get_device()
device_mod = getattr(torch, device_type, torch.cuda)
device_mod.synchronize()


def enable_full_determinism():
"""
Helper function for reproducible behavior during distributed training. See
- https://pytorch.org/docs/stable/notes/randomness.html for pytorch
"""
# Enable PyTorch deterministic mode. This potentially requires either the environment
# variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
# depending on the CUDA version, so we set them both here
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
torch.use_deterministic_algorithms(True)

# Enable CUDNN deterministic mode
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = False


def disable_full_determinism():
os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ""
torch.use_deterministic_algorithms(False)


if is_torch_available():
torch_device = get_device()
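The new `torch_utils` helpers are table-driven: `_device_agnostic_dispatch` looks up the device key, calls the entry if it is callable, and otherwise returns it directly (which is why the CPU/MPS memory-stat entries can simply be `0` or `None`). A hypothetical usage illustration, not part of the PR:

```python
from diffusers.utils.torch_utils import (
    backend_device_count,
    backend_manual_seed,
    backend_max_memory_allocated,
)

backend_manual_seed("cpu", 0)               # callable entry: torch.manual_seed(0)
print(backend_device_count("cuda"))         # callable entry: torch.cuda.device_count()
print(backend_max_memory_allocated("cpu"))  # non-callable entry: returns 0 as-is
```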